
Commit 67996f8

Internal change
PiperOrigin-RevId: 317898942
1 parent ee35a03 commit 67996f8

File tree: 4 files changed (+268, -7 lines)

official/nlp/bert/input_pipeline.py

Lines changed: 36 additions & 0 deletions
@@ -247,3 +247,39 @@ def _select_data_from_record(record):
   dataset = dataset.batch(batch_size, drop_remainder=True)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
+
+
+def create_retrieval_dataset(file_path,
+                             seq_length,
+                             batch_size,
+                             input_pipeline_context=None):
+  """Creates input dataset from (tf)records files for scoring."""
+  name_to_features = {
+      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+      'int_iden': tf.io.FixedLenFeature([1], tf.int64),
+  }
+  dataset = single_file_dataset(file_path, name_to_features)
+
+  # The dataset is always sharded by number of hosts.
+  # num_input_pipelines is the number of hosts rather than number of cores.
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)
+
+  def _select_data_from_record(record):
+    x = {
+        'input_word_ids': record['input_ids'],
+        'input_mask': record['input_mask'],
+        'input_type_ids': record['segment_ids']
+    }
+    y = record['int_iden']
+    return (x, y)
+
+  dataset = dataset.map(
+      _select_data_from_record,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  dataset = dataset.batch(batch_size, drop_remainder=False)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+  return dataset

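Usage note: a minimal sketch of how the new `create_retrieval_dataset` helper could be consumed. The TFRecord path and batch size below are hypothetical; the import path follows the file location above.

import tensorflow as tf

from official.nlp.bert import input_pipeline

# Hypothetical path to a retrieval TFRecord produced by create_finetuning_data.py.
retrieval_file = "/tmp/bucc_tfrecords/de-en-de.dev.tfrecords"

dataset = input_pipeline.create_retrieval_dataset(
    retrieval_file, seq_length=128, batch_size=32)

# Each element is (features, int_iden); the last batch may be smaller because
# the function batches with drop_remainder=False.
for features, int_iden in dataset.take(1):
  print(features["input_word_ids"].shape)  # e.g. (32, 128)
  print(int_iden.shape)                    # e.g. (32, 1)
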
official/nlp/data/classifier_data_lib.py

Lines changed: 21 additions & 5 deletions
@@ -33,7 +33,13 @@
 class InputExample(object):
   """A single training/test example for simple sequence classification."""

-  def __init__(self, guid, text_a, text_b=None, label=None, weight=None):
+  def __init__(self,
+               guid,
+               text_a,
+               text_b=None,
+               label=None,
+               weight=None,
+               int_iden=None):
     """Constructs a InputExample.

     Args:
@@ -46,12 +52,15 @@ def __init__(self, guid, text_a, text_b=None, label=None, weight=None):
         specified for train and dev examples, but not for test examples.
       weight: (Optional) float. The weight of the example to be used during
         training.
+      int_iden: (Optional) int. The int identification number of example in the
+        corpus.
     """
     self.guid = guid
     self.text_a = text_a
     self.text_b = text_b
     self.label = label
     self.weight = weight
+    self.int_iden = int_iden


 class InputFeatures(object):
@@ -63,13 +72,15 @@ def __init__(self,
                segment_ids,
                label_id,
                is_real_example=True,
-               weight=None):
+               weight=None,
+               int_iden=None):
     self.input_ids = input_ids
     self.input_mask = input_mask
     self.segment_ids = segment_ids
     self.label_id = label_id
     self.is_real_example = is_real_example
     self.weight = weight
+    self.int_iden = int_iden


 class DataProcessor(object):
@@ -908,16 +919,19 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
     logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
     logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
     logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
-    logging.info("label: %s (id = %d)", example.label, label_id)
+    logging.info("label: %s (id = %s)", example.label, str(label_id))
     logging.info("weight: %s", example.weight)
+    logging.info("int_iden: %s", str(example.int_iden))

   feature = InputFeatures(
       input_ids=input_ids,
       input_mask=input_mask,
       segment_ids=segment_ids,
       label_id=label_id,
       is_real_example=True,
-      weight=example.weight)
+      weight=example.weight,
+      int_iden=example.int_iden)
+
   return feature


@@ -953,12 +967,14 @@ def create_float_feature(values):
     features["segment_ids"] = create_int_feature(feature.segment_ids)
     if label_type is not None and label_type == float:
       features["label_ids"] = create_float_feature([feature.label_id])
-    else:
+    elif feature.label_id is not None:
       features["label_ids"] = create_int_feature([feature.label_id])
     features["is_real_example"] = create_int_feature(
         [int(feature.is_real_example)])
     if feature.weight is not None:
       features["weight"] = create_float_feature([feature.weight])
+    if feature.int_iden is not None:
+      features["int_iden"] = create_int_feature([feature.int_iden])

     tf_example = tf.train.Example(features=tf.train.Features(feature=features))
     writer.write(tf_example.SerializeToString())

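Usage note: a small sketch of how the new `int_iden` field travels from an `InputExample` into the serialized features, mirroring the logic added above. The sentence and id values are made up, and the raw `tf.train.Feature` call stands in for the module's nested `create_int_feature` helper.

import tensorflow as tf

from official.nlp.data import classifier_data_lib

# Made-up retrieval example: no label, only a corpus id.
example = classifier_data_lib.InputExample(
    guid="sample-0",
    text_a="Das ist ein Beispielsatz.",
    int_iden=17)

# After tokenization this would become an InputFeatures; label_id stays None,
# so with the `elif feature.label_id is not None` change above, no "label_ids"
# feature is written for retrieval data.
feature = classifier_data_lib.InputFeatures(
    input_ids=[101, 0, 102],
    input_mask=[1, 1, 1],
    segment_ids=[0, 0, 0],
    label_id=None,
    int_iden=example.int_iden)

features = {}
if feature.int_iden is not None:
  features["int_iden"] = tf.train.Feature(
      int64_list=tf.train.Int64List(value=[feature.int_iden]))
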
official/nlp/data/create_finetuning_data.py

Lines changed: 43 additions & 2 deletions
@@ -27,6 +27,7 @@
 import tensorflow as tf
 from official.nlp.bert import tokenization
 from official.nlp.data import classifier_data_lib
+from official.nlp.data import sentence_retrieval_lib
 # word-piece tokenizer based squad_lib
 from official.nlp.data import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
@@ -36,7 +37,7 @@

 flags.DEFINE_enum(
     "fine_tuning_task_type", "classification",
-    ["classification", "regression", "squad"],
+    ["classification", "regression", "squad", "retrieval"],
     "The name of the BERT fine tuning task for which data "
     "will be generated..")

@@ -55,6 +56,9 @@
     "only and for XNLI is all languages combined. Same for "
     "PAWS-X.")

+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+                  "The name of sentence retrieval task for scoring")
+
 # XNLI task specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
@@ -246,6 +250,39 @@ def generate_squad_dataset():
       FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)


+def generate_retrieval_dataset():
+  """Generate retrieval test and dev dataset and returns input meta data."""
+  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+  processors = {
+      "bucc": sentence_retrieval_lib.BuccProcessor,
+      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
+  }
+
+  task_name = FLAGS.retrieval_task_name.lower()
+  if task_name not in processors:
+    raise ValueError("Task not found: %s" % task_name)
+
+  processor = processors[task_name](process_text_fn=processor_text_fn)
+
+  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
+      processor,
+      FLAGS.input_data_dir,
+      tokenizer,
+      FLAGS.eval_data_output_path,
+      FLAGS.test_data_output_path,
+      FLAGS.max_seq_length)
+
+
 def main(_):
   if FLAGS.tokenizer_impl == "word_piece":
     if not FLAGS.vocab_file:
@@ -257,10 +294,15 @@ def main(_):
     raise ValueError(
         "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

+  if FLAGS.fine_tuning_task_type != "retrieval":
+    flags.mark_flag_as_required("train_data_output_path")
+
   if FLAGS.fine_tuning_task_type == "classification":
     input_meta_data = generate_classifier_dataset()
   elif FLAGS.fine_tuning_task_type == "regression":
     input_meta_data = generate_regression_dataset()
+  elif FLAGS.fine_tuning_task_type == "retrieval":
+    input_meta_data = generate_retrieval_dataset()
   else:
     input_meta_data = generate_squad_dataset()

@@ -270,6 +312,5 @@ def main(_):


 if __name__ == "__main__":
-  flags.mark_flag_as_required("train_data_output_path")
   flags.mark_flag_as_required("meta_data_file_path")
   app.run(main)
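
Usage note: with the changes above, `train_data_output_path` is only required for non-retrieval tasks. Below is a hedged sketch of the flag values the retrieval path reads; every path is hypothetical.

# Hypothetical flag values for running create_finetuning_data.py in the new
# retrieval mode; only flags referenced by generate_retrieval_dataset() and
# main() are listed.
RETRIEVAL_FLAGS = {
    "fine_tuning_task_type": "retrieval",
    "retrieval_task_name": "bucc",                   # or "tatoeba"
    "tokenizer_impl": "word_piece",                  # or "sentence_piece" with sp_model_file
    "vocab_file": "/tmp/vocab.txt",                  # hypothetical
    "input_data_dir": "/tmp/bucc_data",              # hypothetical
    "eval_data_output_path": "/tmp/bucc_tfrecords",  # hypothetical
    "test_data_output_path": "/tmp/bucc_tfrecords",  # hypothetical
    "max_seq_length": 128,
    "meta_data_file_path": "/tmp/bucc_meta.json",    # hypothetical
}
# Note: train_data_output_path is no longer marked required when
# fine_tuning_task_type == "retrieval".
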
official/nlp/data/sentence_retrieval_lib.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""BERT library to process data for cross lingual sentence retrieval task."""
+
+import os
+
+from absl import logging
+from official.nlp.bert import tokenization
+from official.nlp.data import classifier_data_lib
+
+
+class BuccProcessor(classifier_data_lib.DataProcessor):
+  """Procssor for Xtreme BUCC data set."""
+  supported_languages = ["de", "fr", "ru", "zh"]
+
+  def __init__(self,
+               process_text_fn=tokenization.convert_to_unicode):
+    super(BuccProcessor, self).__init__(process_text_fn)
+    self.languages = BuccProcessor.supported_languages
+
+  def get_dev_examples(self, data_dir, file_pattern):
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
+        "sample")
+
+  def get_test_examples(self, data_dir, file_pattern):
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
+        "test")
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "BUCC"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "%s-%s" % (set_type, i)
+      int_iden = int(line[0].split("-")[1])
+      text_a = self.process_text_fn(line[1])
+      examples.append(
+          classifier_data_lib.InputExample(
+              guid=guid, text_a=text_a, int_iden=int_iden))
+    return examples
+
+
+class TatoebaProcessor(classifier_data_lib.DataProcessor):
+  """Procssor for Xtreme Tatoeba data set."""
+  supported_languages = [
+      "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
+      "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
+      "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
+  ]
+
+  def __init__(self,
+               process_text_fn=tokenization.convert_to_unicode):
+    super(TatoebaProcessor, self).__init__(process_text_fn)
+    self.languages = TatoebaProcessor.supported_languages
+
+  def get_test_examples(self, data_dir, file_path):
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, file_path)), "test")
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "TATOEBA"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "%s-%s" % (set_type, i)
+      text_a = self.process_text_fn(line[0])
+      examples.append(
+          classifier_data_lib.InputExample(
+              guid=guid, text_a=text_a, int_iden=i))
+    return examples
+
+
+def generate_sentence_retrevial_tf_record(processor,
+                                          data_dir,
+                                          tokenizer,
+                                          eval_data_output_path=None,
+                                          test_data_output_path=None,
+                                          max_seq_length=128):
+  """Generates the tf records for retrieval tasks.
+
+  Args:
+    processor: Input processor object to be used for generating data. Subclass
+      of `DataProcessor`.
+    data_dir: Directory that contains train/eval data to process. Data files
+      should be in from.
+    tokenizer: The tokenizer to be applied on the data.
+    eval_data_output_path: Output to which processed tf record for evaluation
+      will be saved.
+    test_data_output_path: Output to which processed tf record for testing
+      will be saved. Must be a pattern template with {} if processor has
+      language specific test data.
+    max_seq_length: Maximum sequence length of the to be generated
+      training/eval data.
+
+  Returns:
+    A dictionary containing input meta data.
+  """
+  assert eval_data_output_path or test_data_output_path
+
+  if processor.get_processor_name() == "BUCC":
+    path_pattern = "{}-en.{{}}.{}"
+
+  if processor.get_processor_name() == "TATOEBA":
+    path_pattern = "{}-en.{}"
+
+  meta_data = {
+      "processor_type": processor.get_processor_name(),
+      "max_seq_length": max_seq_length,
+      "number_eval_data": {},
+      "number_test_data": {},
+  }
+  logging.info("Start to process %s task data", processor.get_processor_name())
+
+  for lang_a in processor.languages:
+    for lang_b in [lang_a, "en"]:
+      if eval_data_output_path:
+        eval_input_data_examples = processor.get_dev_examples(
+            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
+
+        num_eval_data = len(eval_input_data_examples)
+        logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
+                     lang_a, lang_b)
+        output_file = os.path.join(
+            eval_data_output_path,
+            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
+        classifier_data_lib.file_based_convert_examples_to_features(
+            eval_input_data_examples, None, max_seq_length, tokenizer,
+            output_file, None)
+        meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
+
+      if test_data_output_path:
+        test_input_data_examples = processor.get_test_examples(
+            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
+
+        num_test_data = len(test_input_data_examples)
+        logging.info("Processing %d test examples of %s-en.%s", num_test_data,
+                     lang_a, lang_b)
+        output_file = os.path.join(
+            test_data_output_path,
+            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
+        classifier_data_lib.file_based_convert_examples_to_features(
+            test_input_data_examples, None, max_seq_length, tokenizer,
+            output_file, None)
+        meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
+
+  return meta_data

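Usage note: to make the expected input file naming concrete, a short sketch of the paths the two processors end up reading, derived from the `path_pattern` strings above; the data directory is hypothetical.

import os

data_dir = "/tmp/xtreme_data"  # hypothetical
lang_a, lang_b = "de", "en"

# BUCC: "{}-en.{{}}.{}" keeps a literal {} for the split, which
# get_dev_examples / get_test_examples later fill with "dev" or "test".
bucc_pattern = "{}-en.{{}}.{}".format(lang_a, lang_b)  # "de-en.{}.en"
print(os.path.join(data_dir, bucc_pattern.format("dev")))
# -> /tmp/xtreme_data/de-en.dev.en

# Tatoeba: "{}-en.{}" points at a single test file per language pair.
tatoeba_path = "{}-en.{}".format(lang_a, lang_b)       # "de-en.en"
print(os.path.join(data_dir, tatoeba_path))
# -> /tmp/xtreme_data/de-en.en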