Skip to content

Commit 169e405

Browse files
Internal change
PiperOrigin-RevId: 381152422
1 parent 4c99ab7 commit 169e405

File tree

3 files changed

+128
-35
lines changed

official/nlp/data/classifier_data_lib.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,20 +1287,17 @@ class SuperGLUEDataProcessor(DataProcessor):
12871287

12881288
def get_train_examples(self, data_dir):
  """See base class."""
  # NOTE: `data_dir` is ignored; examples are loaded from TFDS, not local files.
  return self._create_examples_tfds("train")
12921291

12931292
def get_dev_examples(self, data_dir):
  """See base class."""
  # The dev set maps to the TFDS "validation" split; `data_dir` is ignored.
  return self._create_examples_tfds("validation")
12971295

12981296
def get_test_examples(self, data_dir):
  """See base class."""
  # NOTE: `data_dir` is ignored; examples are loaded from TFDS, not local files.
  return self._create_examples_tfds("test")
13021299

1303-
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Subclasses must override this to load their task's TFDS split.

  Args:
    set_type: TFDS split name, e.g. "train", "validation", or "test".
  """
  raise NotImplementedError()
13061303

@@ -1317,17 +1314,18 @@ def get_processor_name():
13171314
"""See base class."""
13181315
return "BoolQ"
13191316

1320-
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: TFDS split name ("train", "validation", or "test").

  Returns:
    A list of `InputExample`s built from the `super_glue/boolq` TFDS split.
  """
  records = tfds.load(
      "super_glue/boolq", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for record in records:
    guid = "%s-%s" % (set_type, self.process_text_fn(str(record["idx"])))
    question = self.process_text_fn(record["question"])
    passage = self.process_text_fn(record["passage"])
    if set_type == "test":
      # Test labels are hidden in SuperGLUE; use a fixed placeholder.
      label = "False"
    else:
      label = self.get_labels()[record["label"]]
    examples.append(
        InputExample(guid=guid, text_a=question, text_b=passage, label=label))
  return examples
@@ -1345,17 +1343,18 @@ def get_processor_name():
13451343
"""See base class."""
13461344
return "CB"
13471345

1348-
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: TFDS split name ("train", "validation", or "test").

  Returns:
    A list of `InputExample`s built from the `super_glue/cb` TFDS split.
  """
  records = tfds.load(
      "super_glue/cb", split=set_type, try_gcs=True).as_numpy_iterator()
  examples = []
  for record in records:
    guid = "%s-%s" % (set_type, self.process_text_fn(str(record["idx"])))
    premise = self.process_text_fn(record["premise"])
    hypothesis = self.process_text_fn(record["hypothesis"])
    if set_type == "test":
      # Test labels are hidden in SuperGLUE; use a fixed placeholder.
      label = "entailment"
    else:
      label = self.get_labels()[record["label"]]
    examples.append(
        InputExample(guid=guid, text_a=premise, text_b=hypothesis, label=label))
  return examples
@@ -1375,17 +1374,18 @@ def get_processor_name():
13751374
"""See base class."""
13761375
return "RTESuperGLUE"
13771376

1378-
def _create_examples_tfds(self, set_type):
  """Creates examples for the training/dev/test sets.

  Args:
    set_type: TFDS split name ("train", "validation", or "test").

  Returns:
    A list of `InputExample`s built from the `super_glue/rte` TFDS split.
  """
  examples = []
  records = tfds.load(
      "super_glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
  for record in records:
    guid = "%s-%s" % (set_type, self.process_text_fn(str(record["idx"])))
    premise = self.process_text_fn(record["premise"])
    hypothesis = self.process_text_fn(record["hypothesis"])
    if set_type == "test":
      # Test labels are hidden in SuperGLUE; use a fixed placeholder.
      label = "entailment"
    else:
      label = self.get_labels()[record["label"]]
    examples.append(
        InputExample(guid=guid, text_a=premise, text_b=hypothesis, label=label))
  return examples
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""
16+
17+
import os
18+
import tempfile
19+
20+
from absl.testing import parameterized
21+
import tensorflow as tf
22+
import tensorflow_datasets as tfds
23+
24+
from official.nlp.bert import tokenization
25+
from official.nlp.data import classifier_data_lib
26+
27+
28+
def decode_record(record, name_to_features):
  """Decodes a serialized tf.Example record into a dict of tensors.

  Args:
    record: A scalar string tensor holding a serialized `tf.train.Example`.
    name_to_features: Mapping of feature name to `tf.io.FixedLenFeature` spec.

  Returns:
    A dict mapping feature names to parsed tensors.
  """
  return tf.io.parse_single_example(
      serialized=record, features=name_to_features)
31+
32+
33+
class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
  """Tests the TFDS-backed SuperGLUE processors end to end."""

  def setUp(self):
    super(BertClassifierLibTest, self).setUp()
    self.model_dir = self.get_temp_dir()
    # Task name -> processor class under test.
    self.processors = {
        "CB": classifier_data_lib.CBProcessor,
        "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
        "BOOLQ": classifier_data_lib.BoolQProcessor,
    }

    # A tiny WordPiece vocab is enough: the test only checks the pipeline
    # runs and produces well-formed records, not tokenization quality.
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      vocab_writer.write(("\n".join(vocab_tokens) + "\n").encode("utf-8"))
      vocab_file = vocab_writer.name
    self.tokenizer = tokenization.FullTokenizer(vocab_file)

  @parameterized.parameters(
      {"task_type": "CB"},
      {"task_type": "BOOLQ"},
      {"task_type": "SUPERGLUE-RTE"},
  )
  def test_generate_dataset_from_tfds_processor(self, task_type):
    # mock_data serves synthetic TFDS examples so no real download happens.
    with tfds.testing.mock_data(num_examples=5):
      output_path = os.path.join(self.model_dir, task_type)

      processor = self.processors[task_type]()

      classifier_data_lib.generate_tf_record_from_data_file(
          processor,
          None,
          self.tokenizer,
          train_data_output_path=output_path,
          eval_data_output_path=output_path,
          test_data_output_path=output_path)
      self.assertNotEmpty(tf.io.gfile.glob(output_path))

      seq_length = 128
      name_to_features = {
          "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
          "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
          "label_ids": tf.io.FixedLenFeature([], tf.int64),
      }
      train_dataset = tf.data.TFRecordDataset(output_path).map(
          lambda record: decode_record(record, name_to_features))

      # If data is retrieved without error, then all requirements
      # including data type/shapes are met.
      _ = next(iter(train_dataset))
92+
if __name__ == "__main__":
  # Run with the TF test runner so TF-specific fixtures are initialized.
  tf.test.main()

official/nlp/data/create_finetuning_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def generate_classifier_dataset():
175175
"""Generates classifier dataset and returns input meta data."""
176176
if FLAGS.classification_task_name in [
177177
"COLA", "WNLI", "SST-2", "MRPC", "QQP", "STS-B", "MNLI", "QNLI", "RTE",
178-
"AX"
178+
"AX", "SUPERGLUE-RTE", "CB", "BoolQ"
179179
]:
180180
assert not FLAGS.input_data_dir or FLAGS.tfds_params
181181
else:

0 commit comments

Comments
 (0)