import argparse
import itertools
import logging

from toolchain.common.language_detector import LanguageDetector

logger = logging.getLogger(__name__)


class PostAlignmentCleaner:
    """Filter a line-aligned source/target file pair, one pair at a time.

    Each aligned pair is either kept (source line written to the source
    output, target line to the target output) or rejected (one delimited
    line holding the rejection reason and both stripped segments is written
    to the rejection output). A pair is rejected when:

    * either side is empty after stripping whitespace;
    * neither side contains an alphabetic character;
    * language detection is run on it (see ``should_include``) and the
      detected languages differ from ``lang_src`` / ``lang_tgt``.
    """

    DEFAULT_LANGUAGE_DETECTION_THRESHOLD = 40
    DEFAULT_REJECTED_LINE_DELIMITER = "@@@"
    REJECTION_EMPTY = "empty_segment"
    REJECTION_NONALPHA = "nonalpha"
    REJECTION_UNEXPECTED_LANGUAGE = "unexpected_language_[{0}:{1}]"

    def __init__(self, lang_src, lang_tgt, config=None, language_detector=None):
        """Create a cleaner for one language pair.

        Args:
            lang_src: language code expected on the source side; compared
                for equality with ``language_detector.detect`` results.
            lang_tgt: language code expected on the target side.
            config: optional mapping. Recognised keys are
                ``language_detection_threshold`` (converted with ``int``) and
                ``rejected_line_delimiter``; missing keys fall back to the
                class defaults.
            language_detector: object exposing ``detect(text)``. When omitted,
                a new ``LanguageDetector()`` is created for this instance.
        """
        if config is None:
            config = {}
        if language_detector is None:
            # Built per instance: as a default argument it was constructed once
            # at import time and silently shared by every cleaner.
            language_detector = LanguageDetector()

        self.lang_src = lang_src
        self.lang_tgt = lang_tgt
        self.language_detector = language_detector
        self.language_detection_threshold = int(
            config.get("language_detection_threshold", self.DEFAULT_LANGUAGE_DETECTION_THRESHOLD))

        rejected_line_delimiter = config.get("rejected_line_delimiter", self.DEFAULT_REJECTED_LINE_DELIMITER)
        # Escape braces so a delimiter such as "{}" is emitted literally instead
        # of being parsed as a str.format placeholder when the template is filled.
        escaped_delimiter = rejected_line_delimiter.replace("{", "{{").replace("}", "}}")
        self.rejected_line_template = escaped_delimiter.join(["{0}", "{1}", "{2}"])

    def clean(self, input_path_src, input_path_tgt, output_path_src, output_path_tgt, output_path_rejected,
              encoding="utf-8"):
        """Clean the two aligned input files and write kept and rejected pairs.

        Args:
            input_path_src: path of the source-language input file.
            input_path_tgt: path of the target-language input file, aligned
                line by line with ``input_path_src``.
            output_path_src: path to write kept source lines to (overwritten).
            output_path_tgt: path to write kept target lines to (overwritten).
            output_path_rejected: path to write rejection lines to (overwritten).
            encoding: text encoding used for all five files. Defaults to UTF-8
                so multilingual text does not depend on the platform locale.
        """
        logger.info("Cleaning %s and %s to %s and %s with rejections to %s.",
                    input_path_src, input_path_tgt, output_path_src, output_path_tgt, output_path_rejected)
        with open(input_path_src, encoding=encoding) as input_src, \
                open(input_path_tgt, encoding=encoding) as input_tgt, \
                open(output_path_src, "w", encoding=encoding) as output_src, \
                open(output_path_tgt, "w", encoding=encoding) as output_tgt, \
                open(output_path_rejected, "w", encoding=encoding) as output_rejected:
            self.clean_text(input_src, input_tgt, output_src, output_tgt, output_rejected)

    def clean_text(self, input_src, input_tgt, output_src, output_tgt, output_rejected):
        """Route each aligned line pair to the kept outputs or the rejection output.

        Kept lines are written with only their trailing newline removed and
        re-added; the inclusion test itself is done on whitespace-stripped
        text. Processing stops at the end of the shorter input, and a warning
        is logged if the inputs do not have the same number of lines.
        """
        pairs = itertools.zip_longest(input_src, input_tgt)
        for pair_count, (line_src, line_tgt) in enumerate(pairs):
            if line_src is None or line_tgt is None:
                # Plain zip() would have stopped here silently; a length
                # mismatch usually means the alignment itself is broken.
                logger.warning("Source and target inputs have different numbers of lines; "
                               "stopped after %d aligned pairs.", pair_count)
                break

            term_src = line_src.rstrip("\n")
            term_tgt = line_tgt.rstrip("\n")
            should_include, message = self.should_include(term_src.strip(), term_tgt.strip())
            if should_include:
                self.write_file_line(output_src, term_src)
                self.write_file_line(output_tgt, term_tgt)
            else:
                self.write_file_line(output_rejected, message)

    def should_include(self, term_src, term_tgt):
        """Decide whether an aligned pair is kept.

        Language detection is only run when either side has at least
        ``language_detection_threshold`` characters, or when both sides are
        identical.

        Returns:
            ``(True, "")`` when the pair is kept, otherwise ``(False, line)``
            where ``line`` is the formatted rejection line (reason, source
            term, target term joined by the configured delimiter).
        """
        if not term_src or not term_tgt:
            return False, self.rejected_line_template.format(self.REJECTION_EMPTY, term_src, term_tgt)

        if not self.contains_alpha(term_src) and not self.contains_alpha(term_tgt):
            return False, self.rejected_line_template.format(self.REJECTION_NONALPHA, term_src, term_tgt)

        threshold = self.language_detection_threshold
        needs_language_check = (len(term_src) >= threshold
                                or len(term_tgt) >= threshold
                                or term_src == term_tgt)
        if needs_language_check:
            detected_lang_src = self.language_detector.detect(term_src)
            detected_lang_tgt = self.language_detector.detect(term_tgt)
            if detected_lang_src != self.lang_src or detected_lang_tgt != self.lang_tgt:
                reason = self.REJECTION_UNEXPECTED_LANGUAGE.format(detected_lang_src, detected_lang_tgt)
                return False, self.rejected_line_template.format(reason, term_src, term_tgt)

        return True, ""

    def contains_alpha(self, token):
        """Return True if any character of ``token`` is alphabetic (``str.isalpha``)."""
        return any(c.isalpha() for c in token)

    def write_file_line(self, file, text):
        """Write ``text`` followed by a newline to ``file``."""
        file.write(text + "\n")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("lang_src", help="source language code")
    argparser.add_argument("lang_tgt", help="target language code")
    argparser.add_argument("input_path_src", help="path to input file of source language")
    argparser.add_argument("input_path_tgt", help="path to input file of target language")
    argparser.add_argument("output_path_src", help="path to output file of source language")
    argparser.add_argument("output_path_tgt", help="path to output file of target language")
    argparser.add_argument("output_path_rejected", help="path to output rejection file")
    # Defaults come from the class constants so the CLI and the library agree.
    argparser.add_argument("--langdetect_threshold", type=int,
                           default=PostAlignmentCleaner.DEFAULT_LANGUAGE_DETECTION_THRESHOLD,
                           help="check language of only lines of this number of characters or more")
    argparser.add_argument("--rejected_line_delimiter", type=str,
                           default=PostAlignmentCleaner.DEFAULT_REJECTED_LINE_DELIMITER,
                           help="string to use to delimit fields of rejection lines")
    args = argparser.parse_args()

    config = {
        "language_detection_threshold": args.langdetect_threshold,
        "rejected_line_delimiter": args.rejected_line_delimiter,
    }
    PostAlignmentCleaner(args.lang_src, args.lang_tgt, config=config).clean(
        args.input_path_src, args.input_path_tgt,
        args.output_path_src, args.output_path_tgt, args.output_path_rejected)