Skip to content

Commit c73c012

Browse files
Internal change
PiperOrigin-RevId: 381367878
1 parent 88b2a35 commit c73c012

File tree

1 file changed

+51
-134
lines changed

1 file changed

+51
-134
lines changed

official/nlp/data/classifier_data_lib.py

Lines changed: 51 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,30 @@ def _read_jsonl(cls, input_file):
129129
lines.append(json.loads(json_str))
130130
return lines
131131

132+
def featurize_example(self, *kargs, **kwargs):
133+
"""Converts a single `InputExample` into a single `InputFeatures`."""
134+
return convert_single_example(*kargs, **kwargs)
135+
136+
137+
class DefaultGLUEDataProcessor(DataProcessor):
138+
"""Base processor shared by the GLUE and SuperGLUE dataset processors."""
139+
140+
def get_train_examples(self, data_dir):
141+
"""See base class."""
142+
return self._create_examples_tfds("train")
143+
144+
def get_dev_examples(self, data_dir):
145+
"""See base class."""
146+
return self._create_examples_tfds("validation")
147+
148+
def get_test_examples(self, data_dir):
149+
"""See base class."""
150+
return self._create_examples_tfds("test")
151+
152+
def _create_examples_tfds(self, set_type):
153+
"""Creates examples for the training/dev/test sets."""
154+
raise NotImplementedError()
155+
132156

133157
class AxProcessor(DataProcessor):
134158
"""Processor for the AX dataset (GLUE diagnostics dataset)."""
@@ -178,21 +202,9 @@ def _create_examples_tfds(self, dataset, set_type):
178202
return examples
179203

180204

181-
class ColaProcessor(DataProcessor):
205+
class ColaProcessor(DefaultGLUEDataProcessor):
182206
"""Processor for the CoLA data set (GLUE version)."""
183207

184-
def get_train_examples(self, data_dir):
185-
"""See base class."""
186-
return self._create_examples_tfds("train")
187-
188-
def get_dev_examples(self, data_dir):
189-
"""See base class."""
190-
return self._create_examples_tfds("validation")
191-
192-
def get_test_examples(self, data_dir):
193-
"""See base class."""
194-
return self._create_examples_tfds("test")
195-
196208
def get_labels(self):
197209
"""See base class."""
198210
return ["0", "1"]
@@ -315,21 +327,9 @@ def _create_examples_tfds(self, set_type):
315327
return examples
316328

317329

318-
class MrpcProcessor(DataProcessor):
330+
class MrpcProcessor(DefaultGLUEDataProcessor):
319331
"""Processor for the MRPC data set (GLUE version)."""
320332

321-
def get_train_examples(self, data_dir):
322-
"""See base class."""
323-
return self._create_examples_tfds("train")
324-
325-
def get_dev_examples(self, data_dir):
326-
"""See base class."""
327-
return self._create_examples_tfds("validation")
328-
329-
def get_test_examples(self, data_dir):
330-
"""See base class."""
331-
return self._create_examples_tfds("test")
332-
333333
def get_labels(self):
334334
"""See base class."""
335335
return ["0", "1"]
@@ -437,21 +437,9 @@ def get_processor_name():
437437
return "XTREME-PAWS-X"
438438

439439

440-
class QnliProcessor(DataProcessor):
440+
class QnliProcessor(DefaultGLUEDataProcessor):
441441
"""Processor for the QNLI data set (GLUE version)."""
442442

443-
def get_train_examples(self, data_dir):
444-
"""See base class."""
445-
return self._create_examples_tfds("train")
446-
447-
def get_dev_examples(self, data_dir):
448-
"""See base class."""
449-
return self._create_examples_tfds("validation")
450-
451-
def get_test_examples(self, data_dir):
452-
"""See base class."""
453-
return self._create_examples_tfds("test")
454-
455443
def get_labels(self):
456444
"""See base class."""
457445
return ["entailment", "not_entailment"]
@@ -480,21 +468,9 @@ def _create_examples_tfds(self, set_type):
480468
return examples
481469

482470

483-
class QqpProcessor(DataProcessor):
471+
class QqpProcessor(DefaultGLUEDataProcessor):
484472
"""Processor for the QQP data set (GLUE version)."""
485473

486-
def get_train_examples(self, data_dir):
487-
"""See base class."""
488-
return self._create_examples_tfds("train")
489-
490-
def get_dev_examples(self, data_dir):
491-
"""See base class."""
492-
return self._create_examples_tfds("validation")
493-
494-
def get_test_examples(self, data_dir):
495-
"""See base class."""
496-
return self._create_examples_tfds("test")
497-
498474
def get_labels(self):
499475
"""See base class."""
500476
return ["0", "1"]
@@ -523,21 +499,9 @@ def _create_examples_tfds(self, set_type):
523499
return examples
524500

525501

526-
class RteProcessor(DataProcessor):
502+
class RteProcessor(DefaultGLUEDataProcessor):
527503
"""Processor for the RTE data set (GLUE version)."""
528504

529-
def get_train_examples(self, data_dir):
530-
"""See base class."""
531-
return self._create_examples_tfds("train")
532-
533-
def get_dev_examples(self, data_dir):
534-
"""See base class."""
535-
return self._create_examples_tfds("validation")
536-
537-
def get_test_examples(self, data_dir):
538-
"""See base class."""
539-
return self._create_examples_tfds("test")
540-
541505
def get_labels(self):
542506
"""See base class."""
543507
# All datasets are converted to 2-class split, where for 3-class datasets we
@@ -568,21 +532,9 @@ def _create_examples_tfds(self, set_type):
568532
return examples
569533

570534

571-
class SstProcessor(DataProcessor):
535+
class SstProcessor(DefaultGLUEDataProcessor):
572536
"""Processor for the SST-2 data set (GLUE version)."""
573537

574-
def get_train_examples(self, data_dir):
575-
"""See base class."""
576-
return self._create_examples_tfds("train")
577-
578-
def get_dev_examples(self, data_dir):
579-
"""See base class."""
580-
return self._create_examples_tfds("validation")
581-
582-
def get_test_examples(self, data_dir):
583-
"""See base class."""
584-
return self._create_examples_tfds("test")
585-
586538
def get_labels(self):
587539
"""See base class."""
588540
return ["0", "1"]
@@ -609,7 +561,7 @@ def _create_examples_tfds(self, set_type):
609561
return examples
610562

611563

612-
class StsBProcessor(DataProcessor):
564+
class StsBProcessor(DefaultGLUEDataProcessor):
613565
"""Processor for the STS-B data set (GLUE version)."""
614566

615567
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
@@ -618,18 +570,6 @@ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
618570
self.label_type = float
619571
self._labels = None
620572

621-
def get_train_examples(self, data_dir):
622-
"""See base class."""
623-
return self._create_examples_tfds("train")
624-
625-
def get_dev_examples(self, data_dir):
626-
"""See base class."""
627-
return self._create_examples_tfds("validation")
628-
629-
def get_test_examples(self, data_dir):
630-
"""See base class."""
631-
return self._create_examples_tfds("test")
632-
633573
def _create_examples_tfds(self, set_type):
634574
"""Creates examples for the training/dev/test sets."""
635575
dataset = tfds.load(
@@ -786,21 +726,9 @@ def _create_examples(self, split_name, set_type):
786726
return examples
787727

788728

789-
class WnliProcessor(DataProcessor):
729+
class WnliProcessor(DefaultGLUEDataProcessor):
790730
"""Processor for the WNLI data set (GLUE version)."""
791731

792-
def get_train_examples(self, data_dir):
793-
"""See base class."""
794-
return self._create_examples_tfds("train")
795-
796-
def get_dev_examples(self, data_dir):
797-
"""See base class."""
798-
return self._create_examples_tfds("validation")
799-
800-
def get_test_examples(self, data_dir):
801-
"""See base class."""
802-
return self._create_examples_tfds("test")
803-
804732
def get_labels(self):
805733
"""See base class."""
806734
return ["0", "1"]
@@ -1282,27 +1210,7 @@ def _create_examples(self, lines, set_type):
12821210
return examples
12831211

12841212

1285-
class SuperGLUEDataProcessor(DataProcessor):
1286-
"""Processor for the SuperGLUE dataset."""
1287-
1288-
def get_train_examples(self, data_dir):
1289-
"""See base class."""
1290-
return self._create_examples_tfds("train")
1291-
1292-
def get_dev_examples(self, data_dir):
1293-
"""See base class."""
1294-
return self._create_examples_tfds("validation")
1295-
1296-
def get_test_examples(self, data_dir):
1297-
"""See base class."""
1298-
return self._create_examples_tfds("test")
1299-
1300-
def _create_examples_tfds(self, set_type):
1301-
"""Creates examples for the training/dev/test sets."""
1302-
raise NotImplementedError()
1303-
1304-
1305-
class BoolQProcessor(SuperGLUEDataProcessor):
1213+
class BoolQProcessor(DefaultGLUEDataProcessor):
13061214
"""Processor for the BoolQ dataset (SuperGLUE diagnostics dataset)."""
13071215

13081216
def get_labels(self):
@@ -1331,7 +1239,7 @@ def _create_examples_tfds(self, set_type):
13311239
return examples
13321240

13331241

1334-
class CBProcessor(SuperGLUEDataProcessor):
1242+
class CBProcessor(DefaultGLUEDataProcessor):
13351243
"""Processor for the CB dataset (SuperGLUE diagnostics dataset)."""
13361244

13371245
def get_labels(self):
@@ -1360,7 +1268,7 @@ def _create_examples_tfds(self, set_type):
13601268
return examples
13611269

13621270

1363-
class SuperGLUERTEProcessor(SuperGLUEDataProcessor):
1271+
class SuperGLUERTEProcessor(DefaultGLUEDataProcessor):
13641272
"""Processor for the RTE dataset (SuperGLUE version)."""
13651273

13661274
def get_labels(self):
@@ -1396,7 +1304,8 @@ def file_based_convert_examples_to_features(examples,
13961304
max_seq_length,
13971305
tokenizer,
13981306
output_file,
1399-
label_type=None):
1307+
label_type=None,
1308+
featurize_fn=None):
14001309
"""Convert a set of `InputExample`s to a TFRecord file."""
14011310

14021311
tf.io.gfile.makedirs(os.path.dirname(output_file))
@@ -1406,8 +1315,12 @@ def file_based_convert_examples_to_features(examples,
14061315
if ex_index % 10000 == 0:
14071316
logging.info("Writing example %d of %d", ex_index, len(examples))
14081317

1409-
feature = convert_single_example(ex_index, example, label_list,
1410-
max_seq_length, tokenizer)
1318+
if featurize_fn:
1319+
feature = featurize_fn(ex_index, example, label_list, max_seq_length,
1320+
tokenizer)
1321+
else:
1322+
feature = convert_single_example(ex_index, example, label_list,
1323+
max_seq_length, tokenizer)
14111324

14121325
def create_int_feature(values):
14131326
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
@@ -1496,15 +1409,17 @@ def generate_tf_record_from_data_file(processor,
14961409
file_based_convert_examples_to_features(train_input_data_examples,
14971410
label_list, max_seq_length,
14981411
tokenizer, train_data_output_path,
1499-
label_type)
1412+
label_type,
1413+
processor.featurize_example)
15001414
num_training_data = len(train_input_data_examples)
15011415

15021416
if eval_data_output_path:
15031417
eval_input_data_examples = processor.get_dev_examples(data_dir)
15041418
file_based_convert_examples_to_features(eval_input_data_examples,
15051419
label_list, max_seq_length,
15061420
tokenizer, eval_data_output_path,
1507-
label_type)
1421+
label_type,
1422+
processor.featurize_example)
15081423

15091424
meta_data = {
15101425
"processor_type": processor.get_processor_name(),
@@ -1518,13 +1433,15 @@ def generate_tf_record_from_data_file(processor,
15181433
for language, examples in test_input_data_examples.items():
15191434
file_based_convert_examples_to_features(
15201435
examples, label_list, max_seq_length, tokenizer,
1521-
test_data_output_path.format(language), label_type)
1436+
test_data_output_path.format(language), label_type,
1437+
processor.featurize_example)
15221438
meta_data["test_{}_data_size".format(language)] = len(examples)
15231439
else:
15241440
file_based_convert_examples_to_features(test_input_data_examples,
15251441
label_list, max_seq_length,
15261442
tokenizer, test_data_output_path,
1527-
label_type)
1443+
label_type,
1444+
processor.featurize_example)
15281445
meta_data["test_data_size"] = len(test_input_data_examples)
15291446

15301447
if is_regression:

0 commit comments

Comments
 (0)