# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| 15 | +"""BERT library to process data for cross lingual sentence retrieval task.""" |

import os

from absl import logging
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib


class BuccProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Xtreme BUCC data set."""
  supported_languages = ["de", "fr", "ru", "zh"]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode):
    super(BuccProcessor, self).__init__(process_text_fn)
    self.languages = BuccProcessor.supported_languages

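  # `file_pattern` is expected to contain a "{}" placeholder for the split
  # name (e.g. "de-en.{}.de"); it is filled with "dev" or "test" below.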
  def get_dev_examples(self, data_dir, file_pattern):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
        "sample")

  def get_test_examples(self, data_dir, file_pattern):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
        "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "BUCC"

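  # Each BUCC TSV line is expected to look like "<lang>-<number>\t<sentence>";
  # only the numeric part of the id is kept as the integer identifier.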
  def _create_examples(self, lines, set_type):
    """Creates examples for the dev and test sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      int_iden = int(line[0].split("-")[1])
      text_a = self.process_text_fn(line[1])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, int_iden=int_iden))
    return examples


class TatoebaProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Xtreme Tatoeba data set."""
  supported_languages = [
      "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
      "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode):
    super(TatoebaProcessor, self).__init__(process_text_fn)
    self.languages = TatoebaProcessor.supported_languages

  def get_test_examples(self, data_dir, file_path):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_path)), "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "TATOEBA"

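  # Tatoeba files contain one sentence per line with no ids, so the line
  # index is used as the integer identifier.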
  def _create_examples(self, lines, set_type):
    """Creates examples for the test set."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = self.process_text_fn(line[0])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, int_iden=i))
    return examples


def generate_sentence_retrevial_tf_record(processor,
                                          data_dir,
                                          tokenizer,
                                          eval_data_output_path=None,
                                          test_data_output_path=None,
                                          max_seq_length=128):
  """Generates the tf records for retrieval tasks.

  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor`.
    data_dir: Directory that contains the eval/test data to process. Data files
      should be named according to the pattern the processor expects, e.g.
      "de-en.{split}.de" for BUCC and "de-en.de" for Tatoeba.
    tokenizer: The tokenizer to be applied on the data.
    eval_data_output_path: Directory to which processed tf records for
      evaluation will be written.
    test_data_output_path: Directory to which processed tf records for testing
      will be written.
    max_seq_length: Maximum sequence length of the training/eval data to be
      generated.

  Returns:
    A dictionary containing input meta data.
  """
  assert eval_data_output_path or test_data_output_path

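  # Each processor expects its own file naming scheme: BUCC files look like
  # "de-en.{}.de" (the "{}" is later filled with the split name), while
  # Tatoeba files look like "de-en.de".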
  if processor.get_processor_name() == "BUCC":
    path_pattern = "{}-en.{{}}.{}"

  if processor.get_processor_name() == "TATOEBA":
    path_pattern = "{}-en.{}"

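  # The meta data records the processor name, the sequence length, and the
  # number of eval/test examples per language pair.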
  meta_data = {
      "processor_type": processor.get_processor_name(),
      "max_seq_length": max_seq_length,
      "number_eval_data": {},
      "number_test_data": {},
  }
  logging.info("Starting to process %s task data",
               processor.get_processor_name())

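  # For every language pair, both sides are converted: the non-English file
  # (lang_b == lang_a) and the English file (lang_b == "en"), each written to
  # its own tfrecord file.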
  for lang_a in processor.languages:
    for lang_b in [lang_a, "en"]:
      if eval_data_output_path:
        eval_input_data_examples = processor.get_dev_examples(
            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))

        num_eval_data = len(eval_input_data_examples)
        logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            eval_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
        classifier_data_lib.file_based_convert_examples_to_features(
            eval_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data

      if test_data_output_path:
        test_input_data_examples = processor.get_test_examples(
            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))

        num_test_data = len(test_input_data_examples)
        logging.info("Processing %d test examples of %s-en.%s", num_test_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            test_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
        classifier_data_lib.file_based_convert_examples_to_features(
            test_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data

  return meta_data
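
# Example usage (a minimal sketch; the vocab file and directories below are
# placeholders, not part of this module):
#
#   processor = BuccProcessor()
#   tokenizer = tokenization.FullTokenizer(
#       vocab_file="/path/to/vocab.txt", do_lower_case=False)
#   meta_data = generate_sentence_retrevial_tf_record(
#       processor,
#       data_dir="/path/to/bucc_data",
#       tokenizer=tokenizer,
#       eval_data_output_path="/path/to/eval_tfrecords",
#       test_data_output_path="/path/to/test_tfrecords",
#       max_seq_length=128)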