post_alignment_cleaner.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import argparse
  2. import logging
  3. from toolchain.common.language_detector import LanguageDetector
  4. logger = logging.getLogger(__name__)
  5. class PostAlignmentCleaner:
  6. DEFAULT_LANGUAGE_DETECTION_THRESHOLD = 40
  7. DEFAULT_REJECTED_LINE_DELIMITER = "@@@"
  8. REJECTION_EMPTY = "empty_segment"
  9. REJECTION_NONALPHA = "nonalpha"
  10. REJECTION_UNEXPECTED_LANGUAGE = "unexpected_language_[{0}:{1}]"
  11. def __init__(self, lang_src, lang_tgt, config={}, language_detector=LanguageDetector()):
  12. self.lang_src = lang_src
  13. self.lang_tgt = lang_tgt
  14. self.language_detector = language_detector
  15. self.language_detection_threshold = int(config.get("language_detection_threshold", self.DEFAULT_LANGUAGE_DETECTION_THRESHOLD))
  16. rejected_line_delimiter = config.get("rejected_line_delimiter", self.DEFAULT_REJECTED_LINE_DELIMITER)
  17. self.rejected_line_template = rejected_line_delimiter.join(["{0}", "{1}", "{2}"])
  18. def clean(self, input_path_src, input_path_tgt, output_path_src, output_path_tgt, output_path_rejected):
  19. logger.info("Cleaning {0} and {1} to {2} and {3} with rejections to {4}.".format(input_path_src, input_path_tgt, output_path_src, output_path_tgt, output_path_rejected))
  20. with open(input_path_src) as input_src, open(input_path_tgt) as input_tgt,\
  21. open(output_path_src, "w") as output_src, open(output_path_tgt, "w") as output_tgt, open(output_path_rejected, "w") as output_rejected:
  22. self.clean_text(input_src, input_tgt, output_src, output_tgt, output_rejected)
  23. def clean_text(self, input_src, input_tgt, output_src, output_tgt, output_rejected):
  24. for input_pair in zip(input_src, input_tgt):
  25. term_src, term_tgt = input_pair[0].rstrip("\n"), input_pair[1].rstrip("\n")
  26. should_include, message = self.should_include(term_src.strip(), term_tgt.strip())
  27. if should_include:
  28. self.write_file_line(output_src, term_src)
  29. self.write_file_line(output_tgt, term_tgt)
  30. else:
  31. self.write_file_line(output_rejected, message)
  32. def should_include(self, term_src, term_tgt):
  33. if not term_src or not term_tgt:
  34. message = self.rejected_line_template.format(self.REJECTION_EMPTY, term_src, term_tgt)
  35. return False, message
  36. if not self.contains_alpha(term_src) and not self.contains_alpha(term_tgt):
  37. message = self.rejected_line_template.format(self.REJECTION_NONALPHA, term_src, term_tgt)
  38. return False, message
  39. if len(term_src) >= self.language_detection_threshold or len(term_tgt) >= self.language_detection_threshold or term_src == term_tgt:
  40. detected_lang_src = self.language_detector.detect(term_src)
  41. detected_lang_tgt = self.language_detector.detect(term_tgt)
  42. if detected_lang_src != self.lang_src or detected_lang_tgt != self.lang_tgt:
  43. reason = self.REJECTION_UNEXPECTED_LANGUAGE.format(detected_lang_src, detected_lang_tgt)
  44. message = self.rejected_line_template.format(reason, term_src, term_tgt)
  45. return False, message
  46. return True, ""
  47. def contains_alpha(self, token):
  48. return any(c.isalpha() for c in token)
  49. def write_file_line(self, file, text):
  50. file.write(text + "\n")
  51. if __name__ == "__main__":
  52. argparser = argparse.ArgumentParser()
  53. argparser.add_argument("lang_src", help="source language code")
  54. argparser.add_argument("lang_tgt", help="target language code")
  55. argparser.add_argument("input_path_src", help="path to input file of source language")
  56. argparser.add_argument("input_path_tgt", help="path to input file of target language")
  57. argparser.add_argument("output_path_src", help="path to output file of source language")
  58. argparser.add_argument("output_path_tgt", help="path to output file of target language")
  59. argparser.add_argument("output_path_rejected", help="path to output rejection file")
  60. argparser.add_argument("--langdetect_threshold", type=int, default=40, help="check language of only lines of this number of characters or more")
  61. argparser.add_argument("--rejected_line_delimiter", type=str, default="@@@", help="string to use to delimit fields of rejection lines")
  62. args = argparser.parse_args()
  63. config = {
  64. "language_detection_threshold" : args.langdetect_threshold,
  65. "rejected_line_delimiter" : args.rejected_line_delimiter,
  66. }
  67. PostAlignmentCleaner(args.lang_src, args.lang_tgt, config=config).clean(args.input_path_src, args.input_path_tgt,
  68. args.output_path_src, args.output_path_tgt, args.output_path_rejected)