Skip to content

Commit 9c9aec1

Browse files
aichendoubletensorflower-gardener
authored andcommitted
Support to run ALBERT on SQuAD task.
PiperOrigin-RevId: 286637307
1 parent 553a4f4 commit 9c9aec1

File tree

4 files changed

+953
-35
lines changed

4 files changed

+953
-35
lines changed

official/nlp/bert/create_finetuning_data.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
import tensorflow as tf
2626

2727
from official.nlp.bert import classifier_data_lib
28-
from official.nlp.bert import squad_lib
28+
# word-piece tokenizer based squad_lib
29+
from official.nlp.bert import squad_lib as squad_lib_wp
30+
# sentence-piece tokenizer based squad_lib
31+
from official.nlp.bert import squad_lib_sp
2932

3033
FLAGS = flags.FLAGS
3134

@@ -70,14 +73,12 @@
7073
flags.DEFINE_string(
7174
"train_data_output_path", None,
7275
"The path in which generated training input data will be written as tf"
73-
" records."
74-
)
76+
" records.")
7577

7678
flags.DEFINE_string(
7779
"eval_data_output_path", None,
7880
"The path in which generated training input data will be written as tf"
79-
" records."
80-
)
81+
" records.")
8182

8283
flags.DEFINE_string("meta_data_file_path", None,
8384
"The path in which input meta data will be written.")
@@ -93,6 +94,15 @@
9394
"Sequences longer than this will be truncated, and sequences shorter "
9495
"than this will be padded.")
9596

97+
flags.DEFINE_string("sp_model_file", "",
98+
"The path to the model used by sentence piece tokenizer.")
99+
100+
flags.DEFINE_enum(
101+
"tokenizer_impl", "word_piece", ["word_piece", "sentence_piece"],
102+
"Specifies the tokenizer implementation, i.e., whehter to use word_piece "
103+
"or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
104+
"while ALBERT uses sentence_piece tokenizer.")
105+
96106

97107
def generate_classifier_dataset():
98108
"""Generates classifier dataset and returns input meta data."""
@@ -124,13 +134,30 @@ def generate_classifier_dataset():
124134
def generate_squad_dataset():
125135
"""Generates squad training dataset and returns input meta data."""
126136
assert FLAGS.squad_data_file
127-
return squad_lib.generate_tf_record_from_json_file(
128-
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
129-
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
130-
FLAGS.doc_stride, FLAGS.version_2_with_negative)
137+
if FLAGS.tokenizer_impl == "word_piece":
138+
return squad_lib_wp.generate_tf_record_from_json_file(
139+
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
140+
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
141+
FLAGS.doc_stride, FLAGS.version_2_with_negative)
142+
else:
143+
assert FLAGS.tokenizer_impl == "sentence_piece"
144+
return squad_lib_sp.generate_tf_record_from_json_file(
145+
FLAGS.squad_data_file, FLAGS.sp_model_file,
146+
FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
147+
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
131148

132149

133150
def main(_):
151+
if FLAGS.tokenizer_impl == "word_piece":
152+
if not FLAGS.vocab_file:
153+
raise ValueError(
154+
"FLAG vocab_file for word-piece tokenizer is not specified.")
155+
else:
156+
assert FLAGS.tokenizer_impl == "sentence_piece"
157+
if not FLAGS.sp_model_file:
158+
raise ValueError(
159+
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")
160+
134161
if FLAGS.fine_tuning_task_type == "classification":
135162
input_meta_data = generate_classifier_dataset()
136163
else:
@@ -141,7 +168,6 @@ def main(_):
141168

142169

143170
if __name__ == "__main__":
144-
flags.mark_flag_as_required("vocab_file")
145171
flags.mark_flag_as_required("train_data_output_path")
146172
flags.mark_flag_as_required("meta_data_file_path")
147173
app.run(main)

official/nlp/bert/run_squad.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
from official.nlp.bert import common_flags
3535
from official.nlp.bert import input_pipeline
3636
from official.nlp.bert import model_saving_utils
37-
from official.nlp.bert import squad_lib
37+
# word-piece tokenizer based squad_lib
38+
from official.nlp.bert import squad_lib as squad_lib_wp
39+
# sentence-piece tokenizer based squad_lib
40+
from official.nlp.bert import squad_lib_sp
3841
from official.nlp.bert import tokenization
3942
from official.utils.misc import distribution_utils
4043
from official.utils.misc import keras_utils
@@ -80,11 +83,22 @@
8083
'max_answer_length', 30,
8184
'The maximum length of an answer that can be generated. This is needed '
8285
'because the start and end predictions are not conditioned on one another.')
86+
flags.DEFINE_string(
87+
'sp_model_file', None,
88+
'The path to the sentence piece model. Used by sentence piece tokenizer '
89+
'employed by ALBERT.')
90+
8391

8492
common_flags.define_common_bert_flags()
8593

8694
FLAGS = flags.FLAGS
8795

96+
MODEL_CLASSES = {
97+
'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
98+
'albert': (modeling.AlbertConfig, squad_lib_sp,
99+
tokenization.FullSentencePieceTokenizer),
100+
}
101+
88102

89103
def squad_loss_fn(start_positions,
90104
end_positions,
@@ -121,6 +135,7 @@ def _loss_fn(labels, model_outputs):
121135

122136
def get_raw_results(predictions):
123137
"""Converts multi-replica predictions to RawResult."""
138+
squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
124139
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
125140
predictions['start_logits'],
126141
predictions['end_logits']):
@@ -167,9 +182,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
167182
# Prediction always uses float32, even if training uses mixed precision.
168183
tf.keras.mixed_precision.experimental.set_policy('float32')
169184
squad_model, _ = bert_models.squad_model(
170-
bert_config,
171-
input_meta_data['max_seq_length'],
172-
float_type=tf.float32)
185+
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
173186

174187
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
175188
logging.info('Restoring checkpoints from %s', checkpoint_path)
@@ -219,7 +232,8 @@ def train_squad(strategy,
219232
if use_float16:
220233
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
221234

222-
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
235+
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
236+
FLAGS.bert_config_file)
223237
epochs = FLAGS.num_train_epochs
224238
num_train_examples = input_meta_data['train_data_size']
225239
max_seq_length = input_meta_data['max_seq_length']
@@ -281,7 +295,14 @@ def _get_squad_model():
281295

282296
def predict_squad(strategy, input_meta_data):
283297
"""Makes predictions for a squad dataset."""
284-
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
298+
config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
299+
bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
300+
if tokenizer_cls == tokenization.FullTokenizer:
301+
tokenizer = tokenizer_cls(
302+
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
303+
else:
304+
assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
305+
tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
285306
doc_stride = input_meta_data['doc_stride']
286307
max_query_length = input_meta_data['max_query_length']
287308
# Whether data should be in Ver 2.0 format.
@@ -292,9 +313,6 @@ def predict_squad(strategy, input_meta_data):
292313
is_training=False,
293314
version_2_with_negative=version_2_with_negative)
294315

295-
tokenizer = tokenization.FullTokenizer(
296-
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
297-
298316
eval_writer = squad_lib.FeatureWriter(
299317
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
300318
is_training=False)
@@ -309,7 +327,7 @@ def _append_feature(feature, is_padding):
309327
# of examples must be a multiple of the batch size, or else examples
310328
# will get dropped. So we pad with fake examples which are ignored
311329
# later on.
312-
dataset_size = squad_lib.convert_examples_to_features(
330+
kwargs = dict(
313331
examples=eval_examples,
314332
tokenizer=tokenizer,
315333
max_seq_length=input_meta_data['max_seq_length'],
@@ -318,6 +336,11 @@ def _append_feature(feature, is_padding):
318336
is_training=False,
319337
output_fn=_append_feature,
320338
batch_size=FLAGS.predict_batch_size)
339+
340+
# squad_lib_sp requires one more argument 'do_lower_case'.
341+
if squad_lib == squad_lib_sp:
342+
kwargs['do_lower_case'] = FLAGS.do_lower_case
343+
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
321344
eval_writer.close()
322345

323346
logging.info('***** Running predictions *****')
@@ -358,12 +381,10 @@ def export_squad(model_export_path, input_meta_data):
358381
"""
359382
if not model_export_path:
360383
raise ValueError('Export path is not specified: %s' % model_export_path)
361-
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
362-
384+
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
385+
FLAGS.bert_config_file)
363386
squad_model, _ = bert_models.squad_model(
364-
bert_config,
365-
input_meta_data['max_seq_length'],
366-
float_type=tf.float32)
387+
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
367388
model_saving_utils.export_bert_model(
368389
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
369390

0 commit comments

Comments
 (0)