
Commit c413c90

aichendouble authored and tensorflower-gardener committed
Support creating a classification dataset using the sentence piece tokenizer.
PiperOrigin-RevId: 286805889
1 parent 01a51ee commit c413c90

2 files changed, 43 additions and 32 deletions


official/nlp/bert/classifier_data_lib.py

Lines changed: 28 additions & 28 deletions
@@ -68,6 +68,9 @@ def __init__(self,
 class DataProcessor(object):
   """Base class for data converters for sequence classification data sets."""
 
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    self.process_text_fn = process_text_fn
+
   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
     raise NotImplementedError()
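
The new base-class constructor lets callers inject how raw TSV fields are normalized, instead of the previously hard-coded tokenization.convert_to_unicode. A minimal usage sketch, not part of the commit; the lower=True choice is an illustrative assumption:

import functools

from official.nlp.bert import classifier_data_lib
from official.nlp.bert import tokenization

# Word-piece style processing can keep the default (convert_to_unicode);
# for sentence piece, a preprocess_text partial is injected instead and
# stored on the processor as self.process_text_fn.
process_text_fn = functools.partial(tokenization.preprocess_text, lower=True)
processor = classifier_data_lib.XnliProcessor(process_text_fn=process_text_fn)
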
@@ -103,7 +106,8 @@ def _read_tsv(cls, input_file, quotechar=None):
 class XnliProcessor(DataProcessor):
   """Processor for the XNLI data set."""
 
-  def __init__(self):
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
     self.language = "zh"
 
   def get_train_examples(self, data_dir):
@@ -116,11 +120,11 @@ def get_train_examples(self, data_dir):
       if i == 0:
         continue
       guid = "train-%d" % (i)
-      text_a = tokenization.convert_to_unicode(line[0])
-      text_b = tokenization.convert_to_unicode(line[1])
-      label = tokenization.convert_to_unicode(line[2])
-      if label == tokenization.convert_to_unicode("contradictory"):
-        label = tokenization.convert_to_unicode("contradiction")
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -133,12 +137,12 @@ def get_dev_examples(self, data_dir):
       if i == 0:
         continue
       guid = "dev-%d" % (i)
-      language = tokenization.convert_to_unicode(line[0])
-      if language != tokenization.convert_to_unicode(self.language):
+      language = self.process_text_fn(line[0])
+      if language != self.process_text_fn(self.language):
         continue
-      text_a = tokenization.convert_to_unicode(line[6])
-      text_b = tokenization.convert_to_unicode(line[7])
-      label = tokenization.convert_to_unicode(line[1])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -187,13 +191,13 @@ def _create_examples(self, lines, set_type):
     for (i, line) in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
-      text_a = tokenization.convert_to_unicode(line[8])
-      text_b = tokenization.convert_to_unicode(line[9])
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
       if set_type == "test":
         label = "contradiction"
       else:
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.process_text_fn(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -233,12 +237,12 @@ def _create_examples(self, lines, set_type):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
-      text_a = tokenization.convert_to_unicode(line[3])
-      text_b = tokenization.convert_to_unicode(line[4])
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
       if set_type == "test":
         label = "0"
       else:
-        label = tokenization.convert_to_unicode(line[0])
+        label = self.process_text_fn(line[0])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -280,11 +284,11 @@ def _create_examples(self, lines, set_type):
         continue
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[1])
         label = "0"
       else:
-        text_a = tokenization.convert_to_unicode(line[3])
-        label = tokenization.convert_to_unicode(line[1])
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
     return examples
@@ -525,35 +529,31 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
 
 def generate_tf_record_from_data_file(processor,
                                        data_dir,
-                                       vocab_file,
+                                       tokenizer,
                                        train_data_output_path=None,
                                        eval_data_output_path=None,
-                                       max_seq_length=128,
-                                       do_lower_case=True):
+                                       max_seq_length=128):
   """Generates and saves training data into a tf record file.
 
   Arguments:
     processor: Input processor object to be used for generating data. Subclass
       of `DataProcessor`.
     data_dir: Directory that contains train/eval data to process. Data files
       should be in from "dev.tsv", "test.tsv", or "train.tsv".
-    vocab_file: Text file with words to be used for training/evaluation.
+    tokenizer: The tokenizer to be applied on the data.
     train_data_output_path: Output to which processed tf record for training
       will be saved.
     eval_data_output_path: Output to which processed tf record for evaluation
       will be saved.
     max_seq_length: Maximum sequence length of the to be generated
       training/eval data.
-    do_lower_case: Whether to lower case input text.
 
   Returns:
     A dictionary containing input meta data.
   """
   assert train_data_output_path or eval_data_output_path
 
   label_list = processor.get_labels()
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=vocab_file, do_lower_case=do_lower_case)
   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
   file_based_convert_examples_to_features(train_input_data_examples, label_list,
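
With vocab_file and do_lower_case removed, the caller now builds the tokenizer and passes it to generate_tf_record_from_data_file. A sketch of the updated call; the sentence piece model path, data directory, and output paths are placeholders, not values from this commit:

from official.nlp.bert import classifier_data_lib
from official.nlp.bert import tokenization

# Tokenizer construction now lives with the caller; a word-piece
# FullTokenizer built from a vocab file would be passed the same way.
tokenizer = tokenization.FullSentencePieceTokenizer("sp_model.model")

input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,                     # e.g. the XnliProcessor sketched above
    "xnli_data",                   # placeholder data_dir with train/dev TSVs
    tokenizer,
    train_data_output_path="xnli_train.tf_record",
    eval_data_output_path="xnli_eval.tf_record",
    max_seq_length=128)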

official/nlp/bert/create_finetuning_data.py

Lines changed: 15 additions & 4 deletions
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import json
 
 from absl import app
@@ -29,6 +30,7 @@
 from official.nlp.bert import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.bert import squad_lib_sp
+from official.nlp.bert import tokenization
 
 FLAGS = flags.FLAGS
 
@@ -120,15 +122,24 @@ def generate_classifier_dataset():
   if task_name not in processors:
     raise ValueError("Task not found: %s" % (task_name))
 
-  processor = processors[task_name]()
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+  processor = processors[task_name](processor_text_fn)
   return classifier_data_lib.generate_tf_record_from_data_file(
       processor,
       FLAGS.input_data_dir,
-      FLAGS.vocab_file,
+      tokenizer,
       train_data_output_path=FLAGS.train_data_output_path,
       eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length,
-      do_lower_case=FLAGS.do_lower_case)
+      max_seq_length=FLAGS.max_seq_length)
 
 
 def generate_squad_dataset():
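
The two tokenizer_impl branches differ in how raw text is normalized before feature conversion: the word_piece path only converts input to unicode (lowercasing happens later inside FullTokenizer), while the sentence_piece path runs tokenization.preprocess_text, which collapses whitespace and can lowercase up front. A small comparison sketch; the sample string and lower=True are assumptions, not part of this commit:

import functools

from official.nlp.bert import tokenization

sample = "The Quick  Brown Fox"

# word_piece branch: plain unicode conversion leaves the text unchanged here.
word_piece_fn = tokenization.convert_to_unicode
print(word_piece_fn(sample))      # "The Quick  Brown Fox"

# sentence_piece branch: preprocess_text trims repeated whitespace and,
# with lower=True, lowercases before the SentencePiece tokenizer sees it.
sentence_piece_fn = functools.partial(tokenization.preprocess_text, lower=True)
print(sentence_piece_fn(sample))  # expected: "the quick brown fox"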
