
Commit 6e5cbee

Internal change
PiperOrigin-RevId: 381396116
1 parent: fd505f4

File tree

3 files changed (+152, -3 lines)

official/nlp/data/classifier_data_lib.py

Lines changed: 133 additions & 0 deletions

@@ -1299,6 +1299,139 @@ def _create_examples_tfds(self, set_type):
     return examples
 
 
+class WiCInputExample(InputExample):
+  """A single training/test example for the WiC dataset (SuperGLUE version)."""
+
+  def __init__(self,
+               guid,
+               text_a,
+               text_b=None,
+               label=None,
+               word=None,
+               weight=None,
+               example_id=None):
+    """A single training/test example for simple seq regression/classification."""
+    super(WiCInputExample, self).__init__(guid, text_a, text_b, label, weight,
+                                          example_id)
+    self.word = word
+
+
+class WiCProcessor(DefaultGLUEDataProcessor):
+  """Processor for the WiC dataset (SuperGLUE version)."""
+
+  def get_labels(self):
+    """Not used."""
+    return []
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "WiCSuperGLUE"
+
+  def _create_examples_tfds(self, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    dataset = tfds.load(
+        "super_glue/wic", split=set_type, try_gcs=True).as_numpy_iterator()
+    for example in dataset:
+      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
+      text_a = self.process_text_fn(example["sentence1"])
+      text_b = self.process_text_fn(example["sentence2"])
+      word = self.process_text_fn(example["word"])
+      label = 0
+      if set_type != "test":
+        label = example["label"]
+      examples.append(
+          WiCInputExample(
+              guid=guid, text_a=text_a, text_b=text_b, word=word, label=label))
+    return examples
+
+  def featurize_example(self, ex_index, example, label_list, max_seq_length,
+                        tokenizer):
+    """Concatenates sentence1, sentence2 and the word with [SEP] tokens."""
+    del label_list
+    tokens_a = tokenizer.tokenize(example.text_a)
+    tokens_b = tokenizer.tokenize(example.text_b)
+    tokens_word = tokenizer.tokenize(example.word)
+
+    # Modifies `tokens_a` and `tokens_b` in place so that the total
+    # length is less than the specified length.
+    # Account for [CLS], [SEP], [SEP], [SEP] with "- 4".
+    # Only the two sentences are truncated; the target word is kept whole.
+    _truncate_seq_pair(tokens_a, tokens_b,
+                       max_seq_length - 4 - len(tokens_word))
+
+    seg_id_a = 0
+    seg_id_b = 1
+    seg_id_c = 2
+    seg_id_cls = 0
+    seg_id_pad = 0
+
+    tokens = []
+    segment_ids = []
+    tokens.append("[CLS]")
+    segment_ids.append(seg_id_cls)
+    for token in tokens_a:
+      tokens.append(token)
+      segment_ids.append(seg_id_a)
+    tokens.append("[SEP]")
+    segment_ids.append(seg_id_a)
+
+    for token in tokens_b:
+      tokens.append(token)
+      segment_ids.append(seg_id_b)
+
+    tokens.append("[SEP]")
+    segment_ids.append(seg_id_b)
+
+    for token in tokens_word:
+      tokens.append(token)
+      segment_ids.append(seg_id_c)
+
+    tokens.append("[SEP]")
+    segment_ids.append(seg_id_c)
+
+    input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    # The mask has 1 for real tokens and 0 for padding tokens. Only real
+    # tokens are attended to.
+    input_mask = [1] * len(input_ids)
+
+    # Zero-pad up to the sequence length.
+    while len(input_ids) < max_seq_length:
+      input_ids.append(0)
+      input_mask.append(0)
+      segment_ids.append(seg_id_pad)
+
+    assert len(input_ids) == max_seq_length
+    assert len(input_mask) == max_seq_length
+    assert len(segment_ids) == max_seq_length
+
+    label_id = example.label
+    if ex_index < 5:
+      logging.info("*** Example ***")
+      logging.info("guid: %s", example.guid)
+      logging.info("tokens: %s",
+                   " ".join([tokenization.printable_text(x) for x in tokens]))
+      logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+      logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+      logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+      logging.info("label: %s (id = %s)", example.label, str(label_id))
+      logging.info("weight: %s", example.weight)
+      logging.info("example_id: %s", example.example_id)
+
+    feature = InputFeatures(
+        input_ids=input_ids,
+        input_mask=input_mask,
+        segment_ids=segment_ids,
+        label_id=label_id,
+        is_real_example=True,
+        weight=example.weight,
+        example_id=example.example_id)
+
+    return feature
+
+
 def file_based_convert_examples_to_features(examples,
                                             label_list,
                                             max_seq_length,
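For intuition, here is a minimal, self-contained sketch of the packing featurize_example produces: [CLS] sentence1 [SEP] sentence2 [SEP] word [SEP], with segment ids 0, 1 and 2 and a truncation budget of max_seq_length - 4 - len(word tokens). It substitutes a toy whitespace split for the BERT WordPiece tokenizer, and truncate_seq_pair below is a stand-in that mirrors the classic BERT helper (pop from the longer of the two sequences); the real _truncate_seq_pair may differ in detail.

def truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Pops from the longer sequence until the pair fits in max_length."""
  while len(tokens_a) + len(tokens_b) > max_length:
    longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    longer.pop()


def pack_wic(sentence1, sentence2, word, max_seq_length=16):
  """Builds [CLS] s1 [SEP] s2 [SEP] word [SEP] with segment ids 0/1/2."""
  tokens_a = sentence1.split()
  tokens_b = sentence2.split()
  tokens_word = word.split()
  # Reserve 4 slots for [CLS] and three [SEP]s; the word is never truncated.
  truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 4 - len(tokens_word))
  tokens = (["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] +
            tokens_word + ["[SEP]"])
  segment_ids = ([0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1) +
                 [2] * (len(tokens_word) + 1))
  return tokens, segment_ids


tokens, segment_ids = pack_wic(
    "the bank raised rates", "she sat on the river bank", "bank")
print(tokens)
# ['[CLS]', 'the', 'bank', 'raised', 'rates', '[SEP]', 'she', 'sat', 'on',
#  'the', 'river', 'bank', '[SEP]', 'bank', '[SEP]']
print(segment_ids)  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2]

Padding then extends input_ids, input_mask and segment_ids with zeros up to max_seq_length, exactly as in the diff above.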

official/nlp/data/classifier_data_lib_test.py

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,7 @@ def setUp(self):
         "CB": classifier_data_lib.CBProcessor,
         "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
         "BOOLQ": classifier_data_lib.BoolQProcessor,
+        "WIC": classifier_data_lib.WiCProcessor,
     }
 
     vocab_tokens = [
@@ -55,6 +56,7 @@ def setUp(self):
       {"task_type": "CB"},
       {"task_type": "BOOLQ"},
       {"task_type": "SUPERGLUE-RTE"},
+      {"task_type": "WIC"},
   )
   def test_generate_dataset_from_tfds_processor(self, task_type):
     with tfds.testing.mock_data(num_examples=5):
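A usage note on the pattern the test relies on: tfds.testing.mock_data intercepts the tfds.load call inside the processor and serves randomly generated examples, so the test needs no SuperGLUE download. A minimal sketch along the same lines, assuming WiCProcessor can be built with its default process_text_fn and calling the private _create_examples_tfds directly for illustration:

import tensorflow_datasets as tfds

from official.nlp.data import classifier_data_lib

# Serves random examples in place of the real super_glue/wic download.
with tfds.testing.mock_data(num_examples=5):
  processor = classifier_data_lib.WiCProcessor()
  examples = processor._create_examples_tfds("train")  # pylint: disable=protected-access
  print(len(examples), examples[0].word)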

official/nlp/data/create_finetuning_data.py

Lines changed: 17 additions & 3 deletions

@@ -50,7 +50,7 @@
     "classification_task_name", "MNLI", [
         "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
         "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
-        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ"
+        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
     ], "The name of the task to train BERT classifier. The "
     "difference between XTREME-XNLI and XNLI is: 1. the format "
     "of input tsv files; 2. the dev set for XTREME is english "
@@ -174,8 +174,20 @@
 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
   if FLAGS.classification_task_name in [
-      "COLA", "WNLI", "SST-2", "MRPC", "QQP", "STS-B", "MNLI", "QNLI", "RTE",
-      "AX", "SUPERGLUE-RTE", "CB", "BoolQ"
+      "COLA",
+      "WNLI",
+      "SST-2",
+      "MRPC",
+      "QQP",
+      "STS-B",
+      "MNLI",
+      "QNLI",
+      "RTE",
+      "AX",
+      "SUPERGLUE-RTE",
+      "CB",
+      "BoolQ",
+      "WIC",
   ]:
     assert not FLAGS.input_data_dir or FLAGS.tfds_params
   else:
@@ -254,6 +266,8 @@ def generate_classifier_dataset():
           classifier_data_lib.CBProcessor,
       "boolq":
           classifier_data_lib.BoolQProcessor,
+      "wic":
+          classifier_data_lib.WiCProcessor,
   }
   task_name = FLAGS.classification_task_name.lower()
   if task_name not in processors:
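For completeness, the lookup at the end of the hunk lowercases the flag value, so --classification_task_name=WIC selects the WiC processor registered above. A minimal sketch of that dispatch, with a hypothetical two-entry excerpt standing in for the full processors dict and a plain variable standing in for FLAGS:

from official.nlp.data import classifier_data_lib

# Hypothetical excerpt of the processors mapping in generate_classifier_dataset().
processors = {
    "boolq": classifier_data_lib.BoolQProcessor,
    "wic": classifier_data_lib.WiCProcessor,
}

classification_task_name = "WIC"  # stand-in for FLAGS.classification_task_name
processor = processors[classification_task_name.lower()]()
print(type(processor).__name__)  # -> WiCProcessor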
