@@ -68,6 +68,9 @@ def __init__(self,
 class DataProcessor(object):
   """Base class for data converters for sequence classification data sets."""
 
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    self.process_text_fn = process_text_fn
+
   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
     raise NotImplementedError()
@@ -103,7 +106,8 @@ def _read_tsv(cls, input_file, quotechar=None):
 class XnliProcessor(DataProcessor):
   """Processor for the XNLI data set."""
 
-  def __init__(self):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
     self.language = "zh"
 
   def get_train_examples(self, data_dir):
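
With `process_text_fn` threaded through the constructor, a caller can inject any callable that maps a raw TSV field to the text the processor should use. A minimal sketch, assuming the `tokenization` module imported at the top of this file; the lowercasing helper is hypothetical and not part of this diff:

# Hypothetical helper, not part of this change: decode to unicode as the
# default does, then lowercase.
def lower_unicode(text):
  return tokenization.convert_to_unicode(text).lower()

# Inject it in place of the default tokenization.convert_to_unicode.
processor = XnliProcessor(process_text_fn=lower_unicode)
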
@@ -116,11 +120,11 @@ def get_train_examples(self, data_dir):
       if i == 0:
         continue
       guid = "train-%d" % (i)
-      text_a = tokenization.convert_to_unicode(line[0])
-      text_b = tokenization.convert_to_unicode(line[1])
-      label = tokenization.convert_to_unicode(line[2])
-      if label == tokenization.convert_to_unicode("contradictory"):
-        label = tokenization.convert_to_unicode("contradiction")
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -133,12 +137,12 @@ def get_dev_examples(self, data_dir):
       if i == 0:
         continue
       guid = "dev-%d" % (i)
-      language = tokenization.convert_to_unicode(line[0])
-      if language != tokenization.convert_to_unicode(self.language):
+      language = self.process_text_fn(line[0])
+      if language != self.process_text_fn(self.language):
         continue
-      text_a = tokenization.convert_to_unicode(line[6])
-      text_b = tokenization.convert_to_unicode(line[7])
-      label = tokenization.convert_to_unicode(line[1])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -187,13 +191,13 @@ def _create_examples(self, lines, set_type):
     for (i, line) in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
-      text_a = tokenization.convert_to_unicode(line[8])
-      text_b = tokenization.convert_to_unicode(line[9])
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
       if set_type == "test":
         label = "contradiction"
       else:
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.process_text_fn(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -233,12 +237,12 @@ def _create_examples(self, lines, set_type):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
-      text_a = tokenization.convert_to_unicode(line[3])
-      text_b = tokenization.convert_to_unicode(line[4])
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
       if set_type == "test":
         label = "0"
       else:
-        label = tokenization.convert_to_unicode(line[0])
+        label = self.process_text_fn(line[0])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -280,11 +284,11 @@ def _create_examples(self, lines, set_type):
         continue
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[1])
         label = "0"
       else:
-        text_a = tokenization.convert_to_unicode(line[3])
-        label = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
     return examples
@@ -525,35 +529,31 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
 
 def generate_tf_record_from_data_file(processor,
                                       data_dir,
-                                      vocab_file,
+                                      tokenizer,
                                       train_data_output_path=None,
                                       eval_data_output_path=None,
-                                      max_seq_length=128,
-                                      do_lower_case=True):
+                                      max_seq_length=128):
   """Generates and saves training data into a tf record file.
 
   Arguments:
     processor: Input processor object to be used for generating data. Subclass
       of `DataProcessor`.
     data_dir: Directory that contains train/eval data to process. Data files
      should be in from "dev.tsv", "test.tsv", or "train.tsv".
-    vocab_file: Text file with words to be used for training/evaluation.
+    tokenizer: The tokenizer to be applied on the data.
    train_data_output_path: Output to which processed tf record for training
      will be saved.
    eval_data_output_path: Output to which processed tf record for evaluation
      will be saved.
    max_seq_length: Maximum sequence length of the to be generated
      training/eval data.
-   do_lower_case: Whether to lower case input text.
 
   Returns:
     A dictionary containing input meta data.
   """
   assert train_data_output_path or eval_data_output_path
 
   label_list = processor.get_labels()
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=vocab_file, do_lower_case=do_lower_case)
   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
   file_based_convert_examples_to_features(train_input_data_examples, label_list,
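
Since the `FullTokenizer` construction moves out of `generate_tf_record_from_data_file`, callers now build the tokenizer themselves and pass it in. A minimal sketch of the new call site, mirroring what the removed lines did internally; the file paths are placeholders:

# Build the tokenizer that the function previously constructed from
# vocab_file/do_lower_case (see the removed lines above).
tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)

input_meta_data = generate_tf_record_from_data_file(
    processor=XnliProcessor(),
    data_dir="/path/to/xnli",  # placeholder directory with the TSV files
    tokenizer=tokenizer,
    train_data_output_path="/tmp/train.tf_record",
    eval_data_output_path="/tmp/eval.tf_record",
    max_seq_length=128)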