123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- import argparse
- import logging
- from toolchain.common.language_detector import LanguageDetector
- logger = logging.getLogger(__name__)
class PostAlignmentCleaner:
    """Filter line-aligned source/target segments, diverting rejected pairs.

    A pair of segments is rejected when:

    * either segment is empty after stripping whitespace;
    * neither segment contains an alphabetic character;
    * language detection is triggered (either segment is at least
      ``language_detection_threshold`` characters long, or both segments are
      identical) and a detected language differs from the expected one.

    Rejected pairs are reported as one line made of the rejection reason, the
    stripped source segment and the stripped target segment, joined by the
    configured delimiter.
    """

    DEFAULT_LANGUAGE_DETECTION_THRESHOLD = 40
    DEFAULT_REJECTED_LINE_DELIMITER = "@@@"
    REJECTION_EMPTY = "empty_segment"
    REJECTION_NONALPHA = "nonalpha"
    REJECTION_UNEXPECTED_LANGUAGE = "unexpected_language_[{0}:{1}]"

    def __init__(self, lang_src, lang_tgt, config=None, language_detector=None):
        """Create a cleaner for one language pair.

        Args:
            lang_src: Expected source language code; compared for equality
                with the value returned by the detector's ``detect``.
            lang_tgt: Expected target language code; compared likewise.
            config: Optional mapping. Recognised keys are
                ``"language_detection_threshold"`` (converted with ``int``)
                and ``"rejected_line_delimiter"``. Missing keys fall back to
                the class defaults. ``None`` means "no overrides".
            language_detector: Object exposing ``detect(text)``. When
                ``None``, a new ``LanguageDetector()`` is created for this
                instance.
        """
        # Defaults are resolved here rather than in the signature: a default
        # written in the signature is evaluated once at definition time, which
        # would build a detector on import and share it (and the config dict)
        # between all instances.
        if config is None:
            config = {}
        if language_detector is None:
            language_detector = LanguageDetector()

        self.lang_src = lang_src
        self.lang_tgt = lang_tgt
        self.language_detector = language_detector
        self.language_detection_threshold = int(
            config.get("language_detection_threshold", self.DEFAULT_LANGUAGE_DETECTION_THRESHOLD)
        )
        rejected_line_delimiter = config.get(
            "rejected_line_delimiter", self.DEFAULT_REJECTED_LINE_DELIMITER
        )
        # Template with three slots: reason, source segment, target segment.
        self.rejected_line_template = rejected_line_delimiter.join(["{0}", "{1}", "{2}"])

    def clean(self, input_path_src, input_path_tgt, output_path_src, output_path_tgt, output_path_rejected):
        """Clean two aligned files on disk.

        Opens both inputs for reading and the three outputs for writing
        (existing output files are truncated), then delegates to
        ``clean_text``.

        Args:
            input_path_src: Path of the source-language input file.
            input_path_tgt: Path of the target-language input file.
            output_path_src: Path receiving accepted source segments.
            output_path_tgt: Path receiving accepted target segments.
            output_path_rejected: Path receiving one line per rejected pair.
        """
        logger.info("Cleaning {0} and {1} to {2} and {3} with rejections to {4}.".format(
            input_path_src, input_path_tgt, output_path_src, output_path_tgt, output_path_rejected))
        with open(input_path_src) as input_src, open(input_path_tgt) as input_tgt, \
                open(output_path_src, "w") as output_src, open(output_path_tgt, "w") as output_tgt, \
                open(output_path_rejected, "w") as output_rejected:
            self.clean_text(input_src, input_tgt, output_src, output_tgt, output_rejected)

    def clean_text(self, input_src, input_tgt, output_src, output_tgt, output_rejected):
        """Clean aligned segments read from already-open line iterables.

        Lines are paired positionally with ``zip``, so processing stops as
        soon as the shorter input is exhausted.

        Args:
            input_src: Iterable of source lines.
            input_tgt: Iterable of target lines.
            output_src: Writable object receiving accepted source segments.
            output_tgt: Writable object receiving accepted target segments.
            output_rejected: Writable object receiving rejection lines.
        """
        for line_src, line_tgt in zip(input_src, input_tgt):
            term_src = line_src.rstrip("\n")
            term_tgt = line_tgt.rstrip("\n")
            # The decision is made on whitespace-stripped text, but accepted
            # segments are written with only the trailing newline removed.
            should_include, message = self.should_include(term_src.strip(), term_tgt.strip())
            if should_include:
                self.write_file_line(output_src, term_src)
                self.write_file_line(output_tgt, term_tgt)
            else:
                self.write_file_line(output_rejected, message)

    def should_include(self, term_src, term_tgt):
        """Decide whether a segment pair is kept.

        Args:
            term_src: Source segment.
            term_tgt: Target segment.

        Returns:
            A ``(keep, message)`` tuple. ``keep`` is ``True`` with an empty
            message for accepted pairs; otherwise ``keep`` is ``False`` and
            ``message`` is the formatted rejection line.
        """
        if not term_src or not term_tgt:
            message = self.rejected_line_template.format(self.REJECTION_EMPTY, term_src, term_tgt)
            return False, message

        # Rejected only when BOTH sides lack alphabetic characters.
        if not self.contains_alpha(term_src) and not self.contains_alpha(term_tgt):
            message = self.rejected_line_template.format(self.REJECTION_NONALPHA, term_src, term_tgt)
            return False, message

        # Detection runs only for segments at or above the length threshold,
        # or for identical source/target text.
        needs_language_check = (
            len(term_src) >= self.language_detection_threshold
            or len(term_tgt) >= self.language_detection_threshold
            or term_src == term_tgt
        )
        if needs_language_check:
            detected_lang_src = self.language_detector.detect(term_src)
            detected_lang_tgt = self.language_detector.detect(term_tgt)
            if detected_lang_src != self.lang_src or detected_lang_tgt != self.lang_tgt:
                reason = self.REJECTION_UNEXPECTED_LANGUAGE.format(detected_lang_src, detected_lang_tgt)
                message = self.rejected_line_template.format(reason, term_src, term_tgt)
                return False, message

        return True, ""

    def contains_alpha(self, token):
        """Return ``True`` if any character of ``token`` satisfies ``str.isalpha``."""
        return any(c.isalpha() for c in token)

    def write_file_line(self, file, text):
        """Write ``text`` followed by a newline to ``file``."""
        file.write(text + "\n")
if __name__ == "__main__":
    # Command-line entry point: parse the language pair, the five file paths
    # and the two optional settings, then run the cleaner once.
    cli = argparse.ArgumentParser()

    # Positional arguments, registered in the order they must be supplied.
    positional_arguments = (
        ("lang_src", "source language code"),
        ("lang_tgt", "target language code"),
        ("input_path_src", "path to input file of source language"),
        ("input_path_tgt", "path to input file of target language"),
        ("output_path_src", "path to output file of source language"),
        ("output_path_tgt", "path to output file of target language"),
        ("output_path_rejected", "path to output rejection file"),
    )
    for argument_name, argument_help in positional_arguments:
        cli.add_argument(argument_name, help=argument_help)

    cli.add_argument(
        "--langdetect_threshold",
        type=int,
        default=40,
        help="check language of only lines of this number of characters or more",
    )
    cli.add_argument(
        "--rejected_line_delimiter",
        type=str,
        default="@@@",
        help="string to use to delimit fields of rejection lines",
    )
    options = cli.parse_args()

    # Map the command-line option names onto the config keys the cleaner reads.
    cleaner_config = {
        "language_detection_threshold": options.langdetect_threshold,
        "rejected_line_delimiter": options.rejected_line_delimiter,
    }
    cleaner = PostAlignmentCleaner(options.lang_src, options.lang_tgt, config=cleaner_config)
    cleaner.clean(
        options.input_path_src,
        options.input_path_tgt,
        options.output_path_src,
        options.output_path_tgt,
        options.output_path_rejected,
    )
|