@@ -129,6 +129,30 @@ def _read_jsonl(cls, input_file):
129
129
lines .append (json .loads (json_str ))
130
130
return lines
131
131
132
+ def featurize_example (self , * kargs , ** kwargs ):
133
+ """Converts a single `InputExample` into a single `InputFeatures`."""
134
+ return convert_single_example (* kargs , ** kwargs )
135
+
136
+
137
+ class DefaultGLUEDataProcessor (DataProcessor ):
138
+ """Processor for the SuperGLUE dataset."""
139
+
140
+ def get_train_examples (self , data_dir ):
141
+ """See base class."""
142
+ return self ._create_examples_tfds ("train" )
143
+
144
+ def get_dev_examples (self , data_dir ):
145
+ """See base class."""
146
+ return self ._create_examples_tfds ("validation" )
147
+
148
+ def get_test_examples (self , data_dir ):
149
+ """See base class."""
150
+ return self ._create_examples_tfds ("test" )
151
+
152
+ def _create_examples_tfds (self , set_type ):
153
+ """Creates examples for the training/dev/test sets."""
154
+ raise NotImplementedError ()
155
+
132
156
133
157
class AxProcessor (DataProcessor ):
134
158
"""Processor for the AX dataset (GLUE diagnostics dataset)."""
@@ -178,21 +202,9 @@ def _create_examples_tfds(self, dataset, set_type):
178
202
return examples
179
203
180
204
181
- class ColaProcessor (DataProcessor ):
205
+ class ColaProcessor (DefaultGLUEDataProcessor ):
182
206
"""Processor for the CoLA data set (GLUE version)."""
183
207
184
- def get_train_examples (self , data_dir ):
185
- """See base class."""
186
- return self ._create_examples_tfds ("train" )
187
-
188
- def get_dev_examples (self , data_dir ):
189
- """See base class."""
190
- return self ._create_examples_tfds ("validation" )
191
-
192
- def get_test_examples (self , data_dir ):
193
- """See base class."""
194
- return self ._create_examples_tfds ("test" )
195
-
196
208
def get_labels (self ):
197
209
"""See base class."""
198
210
return ["0" , "1" ]
@@ -315,21 +327,9 @@ def _create_examples_tfds(self, set_type):
315
327
return examples
316
328
317
329
318
- class MrpcProcessor (DataProcessor ):
330
+ class MrpcProcessor (DefaultGLUEDataProcessor ):
319
331
"""Processor for the MRPC data set (GLUE version)."""
320
332
321
- def get_train_examples (self , data_dir ):
322
- """See base class."""
323
- return self ._create_examples_tfds ("train" )
324
-
325
- def get_dev_examples (self , data_dir ):
326
- """See base class."""
327
- return self ._create_examples_tfds ("validation" )
328
-
329
- def get_test_examples (self , data_dir ):
330
- """See base class."""
331
- return self ._create_examples_tfds ("test" )
332
-
333
333
def get_labels (self ):
334
334
"""See base class."""
335
335
return ["0" , "1" ]
@@ -437,21 +437,9 @@ def get_processor_name():
437
437
return "XTREME-PAWS-X"
438
438
439
439
440
- class QnliProcessor (DataProcessor ):
440
+ class QnliProcessor (DefaultGLUEDataProcessor ):
441
441
"""Processor for the QNLI data set (GLUE version)."""
442
442
443
- def get_train_examples (self , data_dir ):
444
- """See base class."""
445
- return self ._create_examples_tfds ("train" )
446
-
447
- def get_dev_examples (self , data_dir ):
448
- """See base class."""
449
- return self ._create_examples_tfds ("validation" )
450
-
451
- def get_test_examples (self , data_dir ):
452
- """See base class."""
453
- return self ._create_examples_tfds ("test" )
454
-
455
443
def get_labels (self ):
456
444
"""See base class."""
457
445
return ["entailment" , "not_entailment" ]
@@ -480,21 +468,9 @@ def _create_examples_tfds(self, set_type):
480
468
return examples
481
469
482
470
483
- class QqpProcessor (DataProcessor ):
471
+ class QqpProcessor (DefaultGLUEDataProcessor ):
484
472
"""Processor for the QQP data set (GLUE version)."""
485
473
486
- def get_train_examples (self , data_dir ):
487
- """See base class."""
488
- return self ._create_examples_tfds ("train" )
489
-
490
- def get_dev_examples (self , data_dir ):
491
- """See base class."""
492
- return self ._create_examples_tfds ("validation" )
493
-
494
- def get_test_examples (self , data_dir ):
495
- """See base class."""
496
- return self ._create_examples_tfds ("test" )
497
-
498
474
def get_labels (self ):
499
475
"""See base class."""
500
476
return ["0" , "1" ]
@@ -523,21 +499,9 @@ def _create_examples_tfds(self, set_type):
523
499
return examples
524
500
525
501
526
- class RteProcessor (DataProcessor ):
502
+ class RteProcessor (DefaultGLUEDataProcessor ):
527
503
"""Processor for the RTE data set (GLUE version)."""
528
504
529
- def get_train_examples (self , data_dir ):
530
- """See base class."""
531
- return self ._create_examples_tfds ("train" )
532
-
533
- def get_dev_examples (self , data_dir ):
534
- """See base class."""
535
- return self ._create_examples_tfds ("validation" )
536
-
537
- def get_test_examples (self , data_dir ):
538
- """See base class."""
539
- return self ._create_examples_tfds ("test" )
540
-
541
505
def get_labels (self ):
542
506
"""See base class."""
543
507
# All datasets are converted to 2-class split, where for 3-class datasets we
@@ -568,21 +532,9 @@ def _create_examples_tfds(self, set_type):
568
532
return examples
569
533
570
534
571
- class SstProcessor (DataProcessor ):
535
+ class SstProcessor (DefaultGLUEDataProcessor ):
572
536
"""Processor for the SST-2 data set (GLUE version)."""
573
537
574
- def get_train_examples (self , data_dir ):
575
- """See base class."""
576
- return self ._create_examples_tfds ("train" )
577
-
578
- def get_dev_examples (self , data_dir ):
579
- """See base class."""
580
- return self ._create_examples_tfds ("validation" )
581
-
582
- def get_test_examples (self , data_dir ):
583
- """See base class."""
584
- return self ._create_examples_tfds ("test" )
585
-
586
538
def get_labels (self ):
587
539
"""See base class."""
588
540
return ["0" , "1" ]
@@ -609,7 +561,7 @@ def _create_examples_tfds(self, set_type):
609
561
return examples
610
562
611
563
612
- class StsBProcessor (DataProcessor ):
564
+ class StsBProcessor (DefaultGLUEDataProcessor ):
613
565
"""Processor for the STS-B data set (GLUE version)."""
614
566
615
567
def __init__ (self , process_text_fn = tokenization .convert_to_unicode ):
@@ -618,18 +570,6 @@ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
618
570
self .label_type = float
619
571
self ._labels = None
620
572
621
- def get_train_examples (self , data_dir ):
622
- """See base class."""
623
- return self ._create_examples_tfds ("train" )
624
-
625
- def get_dev_examples (self , data_dir ):
626
- """See base class."""
627
- return self ._create_examples_tfds ("validation" )
628
-
629
- def get_test_examples (self , data_dir ):
630
- """See base class."""
631
- return self ._create_examples_tfds ("test" )
632
-
633
573
def _create_examples_tfds (self , set_type ):
634
574
"""Creates examples for the training/dev/test sets."""
635
575
dataset = tfds .load (
@@ -786,21 +726,9 @@ def _create_examples(self, split_name, set_type):
786
726
return examples
787
727
788
728
789
- class WnliProcessor (DataProcessor ):
729
+ class WnliProcessor (DefaultGLUEDataProcessor ):
790
730
"""Processor for the WNLI data set (GLUE version)."""
791
731
792
- def get_train_examples (self , data_dir ):
793
- """See base class."""
794
- return self ._create_examples_tfds ("train" )
795
-
796
- def get_dev_examples (self , data_dir ):
797
- """See base class."""
798
- return self ._create_examples_tfds ("validation" )
799
-
800
- def get_test_examples (self , data_dir ):
801
- """See base class."""
802
- return self ._create_examples_tfds ("test" )
803
-
804
732
def get_labels (self ):
805
733
"""See base class."""
806
734
return ["0" , "1" ]
@@ -1282,27 +1210,7 @@ def _create_examples(self, lines, set_type):
1282
1210
return examples
1283
1211
1284
1212
1285
- class SuperGLUEDataProcessor (DataProcessor ):
1286
- """Processor for the SuperGLUE dataset."""
1287
-
1288
- def get_train_examples (self , data_dir ):
1289
- """See base class."""
1290
- return self ._create_examples_tfds ("train" )
1291
-
1292
- def get_dev_examples (self , data_dir ):
1293
- """See base class."""
1294
- return self ._create_examples_tfds ("validation" )
1295
-
1296
- def get_test_examples (self , data_dir ):
1297
- """See base class."""
1298
- return self ._create_examples_tfds ("test" )
1299
-
1300
- def _create_examples_tfds (self , set_type ):
1301
- """Creates examples for the training/dev/test sets."""
1302
- raise NotImplementedError ()
1303
-
1304
-
1305
- class BoolQProcessor (SuperGLUEDataProcessor ):
1213
+ class BoolQProcessor (DefaultGLUEDataProcessor ):
1306
1214
"""Processor for the BoolQ dataset (SuperGLUE diagnostics dataset)."""
1307
1215
1308
1216
def get_labels (self ):
@@ -1331,7 +1239,7 @@ def _create_examples_tfds(self, set_type):
1331
1239
return examples
1332
1240
1333
1241
1334
- class CBProcessor (SuperGLUEDataProcessor ):
1242
+ class CBProcessor (DefaultGLUEDataProcessor ):
1335
1243
"""Processor for the CB dataset (SuperGLUE diagnostics dataset)."""
1336
1244
1337
1245
def get_labels (self ):
@@ -1360,7 +1268,7 @@ def _create_examples_tfds(self, set_type):
1360
1268
return examples
1361
1269
1362
1270
1363
- class SuperGLUERTEProcessor (SuperGLUEDataProcessor ):
1271
+ class SuperGLUERTEProcessor (DefaultGLUEDataProcessor ):
1364
1272
"""Processor for the RTE dataset (SuperGLUE version)."""
1365
1273
1366
1274
def get_labels (self ):
@@ -1396,7 +1304,8 @@ def file_based_convert_examples_to_features(examples,
1396
1304
max_seq_length ,
1397
1305
tokenizer ,
1398
1306
output_file ,
1399
- label_type = None ):
1307
+ label_type = None ,
1308
+ featurize_fn = None ):
1400
1309
"""Convert a set of `InputExample`s to a TFRecord file."""
1401
1310
1402
1311
tf .io .gfile .makedirs (os .path .dirname (output_file ))
@@ -1406,8 +1315,12 @@ def file_based_convert_examples_to_features(examples,
1406
1315
if ex_index % 10000 == 0 :
1407
1316
logging .info ("Writing example %d of %d" , ex_index , len (examples ))
1408
1317
1409
- feature = convert_single_example (ex_index , example , label_list ,
1410
- max_seq_length , tokenizer )
1318
+ if featurize_fn :
1319
+ feature = featurize_fn (ex_index , example , label_list , max_seq_length ,
1320
+ tokenizer )
1321
+ else :
1322
+ feature = convert_single_example (ex_index , example , label_list ,
1323
+ max_seq_length , tokenizer )
1411
1324
1412
1325
def create_int_feature (values ):
1413
1326
f = tf .train .Feature (int64_list = tf .train .Int64List (value = list (values )))
@@ -1496,15 +1409,17 @@ def generate_tf_record_from_data_file(processor,
1496
1409
file_based_convert_examples_to_features (train_input_data_examples ,
1497
1410
label_list , max_seq_length ,
1498
1411
tokenizer , train_data_output_path ,
1499
- label_type )
1412
+ label_type ,
1413
+ processor .featurize_example )
1500
1414
num_training_data = len (train_input_data_examples )
1501
1415
1502
1416
if eval_data_output_path :
1503
1417
eval_input_data_examples = processor .get_dev_examples (data_dir )
1504
1418
file_based_convert_examples_to_features (eval_input_data_examples ,
1505
1419
label_list , max_seq_length ,
1506
1420
tokenizer , eval_data_output_path ,
1507
- label_type )
1421
+ label_type ,
1422
+ processor .featurize_example )
1508
1423
1509
1424
meta_data = {
1510
1425
"processor_type" : processor .get_processor_name (),
@@ -1518,13 +1433,15 @@ def generate_tf_record_from_data_file(processor,
1518
1433
for language , examples in test_input_data_examples .items ():
1519
1434
file_based_convert_examples_to_features (
1520
1435
examples , label_list , max_seq_length , tokenizer ,
1521
- test_data_output_path .format (language ), label_type )
1436
+ test_data_output_path .format (language ), label_type ,
1437
+ processor .featurize_example )
1522
1438
meta_data ["test_{}_data_size" .format (language )] = len (examples )
1523
1439
else :
1524
1440
file_based_convert_examples_to_features (test_input_data_examples ,
1525
1441
label_list , max_seq_length ,
1526
1442
tokenizer , test_data_output_path ,
1527
- label_type )
1443
+ label_type ,
1444
+ processor .featurize_example )
1528
1445
meta_data ["test_data_size" ] = len (test_input_data_examples )
1529
1446
1530
1447
if is_regression :
0 commit comments