@@ -1299,6 +1299,139 @@ def _create_examples_tfds(self, set_type):
    return examples


+class WiCInputExample(InputExample):
+  """A single training/test example for the WiC dataset (SuperGLUE version)."""
+
+  def __init__(self,
+               guid,
+               text_a,
+               text_b=None,
+               label=None,
+               word=None,
+               weight=None,
+               example_id=None):
+    """A single training/test example for simple seq regression/classification."""
+    super(WiCInputExample, self).__init__(guid, text_a, text_b, label, weight,
+                                          example_id)
+    self.word = word
+
+
+class WiCProcessor(DefaultGLUEDataProcessor):
+  """Processor for the WiC dataset (SuperGLUE version)."""
+
+  def get_labels(self):
+    """Not used."""
+    return []
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "WiCSuperGLUE"
+
+  def _create_examples_tfds(self, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    dataset = tfds.load(
+        "super_glue/wic", split=set_type, try_gcs=True).as_numpy_iterator()
+    for example in dataset:
+      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
+      text_a = self.process_text_fn(example["sentence1"])
+      text_b = self.process_text_fn(example["sentence2"])
+      word = self.process_text_fn(example["word"])
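+      # SuperGLUE test labels are hidden, so a dummy label of 0 is used for
+      # the test split.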
+      label = 0
+      if set_type != "test":
+        label = example["label"]
+      examples.append(
+          WiCInputExample(
+              guid=guid, text_a=text_a, text_b=text_b, word=word, label=label))
+    return examples
+
+  def featurize_example(self, ex_index, example, label_list, max_seq_length,
+                        tokenizer):
+    """Concatenates sentence1, sentence2 and the target word with [SEP] tokens."""
+    del label_list
+    tokens_a = tokenizer.tokenize(example.text_a)
+    tokens_b = tokenizer.tokenize(example.text_b)
+    tokens_word = tokenizer.tokenize(example.word)
+
+    # Modifies `tokens_a` and `tokens_b` in place so that the total
+    # length is less than the specified length.
+    # Account for [CLS], [SEP], [SEP], [SEP] with "- 4".
+    # Only the two sentences are truncated; the target word is kept intact.
+    _truncate_seq_pair(tokens_a, tokens_b,
+                       max_seq_length - 4 - len(tokens_word))
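+    # For example, with max_seq_length=128 and a 2-token target word, the two
+    # sentences together are truncated to at most 128 - 4 - 2 = 122 tokens.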
+
+    seg_id_a = 0
+    seg_id_b = 1
+    seg_id_c = 2
+    seg_id_cls = 0
+    seg_id_pad = 0
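+    # Note: using segment id 2 for the target word assumes the model defines
+    # at least three token type embeddings (standard BERT configs define two).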
+
+    tokens = []
+    segment_ids = []
+    tokens.append("[CLS]")
+    segment_ids.append(seg_id_cls)
+    for token in tokens_a:
+      tokens.append(token)
+      segment_ids.append(seg_id_a)
+    tokens.append("[SEP]")
+    segment_ids.append(seg_id_a)
+
+    for token in tokens_b:
+      tokens.append(token)
+      segment_ids.append(seg_id_b)
+
+    tokens.append("[SEP]")
+    segment_ids.append(seg_id_b)
+
+    for token in tokens_word:
+      tokens.append(token)
+      segment_ids.append(seg_id_c)
+
+    tokens.append("[SEP]")
+    segment_ids.append(seg_id_c)
+
+    input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    # The mask has 1 for real tokens and 0 for padding tokens. Only real
+    # tokens are attended to.
+    input_mask = [1] * len(input_ids)
+
+    # Zero-pad up to the sequence length.
+    while len(input_ids) < max_seq_length:
+      input_ids.append(0)
+      input_mask.append(0)
+      segment_ids.append(seg_id_pad)
+
+    assert len(input_ids) == max_seq_length
+    assert len(input_mask) == max_seq_length
+    assert len(segment_ids) == max_seq_length
+
+    label_id = example.label
+    if ex_index < 5:
+      logging.info("*** Example ***")
+      logging.info("guid: %s", example.guid)
+      logging.info("tokens: %s",
+                   " ".join([tokenization.printable_text(x) for x in tokens]))
+      logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+      logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+      logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+      logging.info("label: %s (id = %s)", example.label, str(label_id))
+      logging.info("weight: %s", example.weight)
+      logging.info("example_id: %s", example.example_id)
+
+    feature = InputFeatures(
+        input_ids=input_ids,
+        input_mask=input_mask,
+        segment_ids=segment_ids,
+        label_id=label_id,
+        is_real_example=True,
+        weight=example.weight,
+        example_id=example.example_id)
+
+    return feature
+
+
 def file_based_convert_examples_to_features(examples,
                                             label_list,
                                             max_seq_length,