34
34
from official .nlp .bert import common_flags
35
35
from official .nlp .bert import input_pipeline
36
36
from official .nlp .bert import model_saving_utils
37
- from official .nlp .bert import squad_lib
37
+ # WordPiece-tokenizer-based squad_lib (used for the 'bert' model type).
38
+ from official .nlp .bert import squad_lib as squad_lib_wp
39
+ # SentencePiece-tokenizer-based squad_lib (used for the 'albert' model type).
40
+ from official .nlp .bert import squad_lib_sp
38
41
from official .nlp .bert import tokenization
39
42
from official .utils .misc import distribution_utils
40
43
from official .utils .misc import keras_utils
80
83
'max_answer_length' , 30 ,
81
84
'The maximum length of an answer that can be generated. This is needed '
82
85
'because the start and end predictions are not conditioned on one another.' )
86
+ flags .DEFINE_string (
87
+ 'sp_model_file' , None ,
88
+ 'The path to the sentence piece model. Used by sentence piece tokenizer '
89
+ 'employed by ALBERT.' )
90
+
83
91
84
92
common_flags .define_common_bert_flags ()
85
93
86
94
FLAGS = flags .FLAGS
87
95
96
+ MODEL_CLASSES = {
97
+ 'bert' : (modeling .BertConfig , squad_lib_wp , tokenization .FullTokenizer ),
98
+ 'albert' : (modeling .AlbertConfig , squad_lib_sp ,
99
+ tokenization .FullSentencePieceTokenizer ),
100
+ }
101
+
88
102
89
103
def squad_loss_fn (start_positions ,
90
104
end_positions ,
@@ -121,6 +135,7 @@ def _loss_fn(labels, model_outputs):
121
135
122
136
def get_raw_results (predictions ):
123
137
"""Converts multi-replica predictions to RawResult."""
138
+ squad_lib = MODEL_CLASSES [FLAGS .model_type ][1 ]
124
139
for unique_ids , start_logits , end_logits in zip (predictions ['unique_ids' ],
125
140
predictions ['start_logits' ],
126
141
predictions ['end_logits' ]):
@@ -167,9 +182,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
167
182
# Prediction always uses float32, even if training uses mixed precision.
168
183
tf .keras .mixed_precision .experimental .set_policy ('float32' )
169
184
squad_model , _ = bert_models .squad_model (
170
- bert_config ,
171
- input_meta_data ['max_seq_length' ],
172
- float_type = tf .float32 )
185
+ bert_config , input_meta_data ['max_seq_length' ], float_type = tf .float32 )
173
186
174
187
checkpoint_path = tf .train .latest_checkpoint (FLAGS .model_dir )
175
188
logging .info ('Restoring checkpoints from %s' , checkpoint_path )
@@ -219,7 +232,8 @@ def train_squad(strategy,
219
232
if use_float16 :
220
233
tf .keras .mixed_precision .experimental .set_policy ('mixed_float16' )
221
234
222
- bert_config = modeling .BertConfig .from_json_file (FLAGS .bert_config_file )
235
+ bert_config = MODEL_CLASSES [FLAGS .model_type ][0 ].from_json_file (
236
+ FLAGS .bert_config_file )
223
237
epochs = FLAGS .num_train_epochs
224
238
num_train_examples = input_meta_data ['train_data_size' ]
225
239
max_seq_length = input_meta_data ['max_seq_length' ]
@@ -281,7 +295,14 @@ def _get_squad_model():
281
295
282
296
def predict_squad (strategy , input_meta_data ):
283
297
"""Makes predictions for a squad dataset."""
284
- bert_config = modeling .BertConfig .from_json_file (FLAGS .bert_config_file )
298
+ config_cls , squad_lib , tokenizer_cls = MODEL_CLASSES [FLAGS .model_type ]
299
+ bert_config = config_cls .from_json_file (FLAGS .bert_config_file )
300
+ if tokenizer_cls == tokenization .FullTokenizer :
301
+ tokenizer = tokenizer_cls (
302
+ vocab_file = FLAGS .vocab_file , do_lower_case = FLAGS .do_lower_case )
303
+ else :
304
+ assert tokenizer_cls == tokenization .FullSentencePieceTokenizer
305
+ tokenizer = tokenizer_cls (sp_model_file = FLAGS .sp_model_file )
285
306
doc_stride = input_meta_data ['doc_stride' ]
286
307
max_query_length = input_meta_data ['max_query_length' ]
287
308
# Whether data should be in Ver 2.0 format.
@@ -292,9 +313,6 @@ def predict_squad(strategy, input_meta_data):
292
313
is_training = False ,
293
314
version_2_with_negative = version_2_with_negative )
294
315
295
- tokenizer = tokenization .FullTokenizer (
296
- vocab_file = FLAGS .vocab_file , do_lower_case = FLAGS .do_lower_case )
297
-
298
316
eval_writer = squad_lib .FeatureWriter (
299
317
filename = os .path .join (FLAGS .model_dir , 'eval.tf_record' ),
300
318
is_training = False )
@@ -309,7 +327,7 @@ def _append_feature(feature, is_padding):
309
327
# of examples must be a multiple of the batch size, or else examples
310
328
# will get dropped. So we pad with fake examples which are ignored
311
329
# later on.
312
- dataset_size = squad_lib . convert_examples_to_features (
330
+ kwargs = dict (
313
331
examples = eval_examples ,
314
332
tokenizer = tokenizer ,
315
333
max_seq_length = input_meta_data ['max_seq_length' ],
@@ -318,6 +336,11 @@ def _append_feature(feature, is_padding):
318
336
is_training = False ,
319
337
output_fn = _append_feature ,
320
338
batch_size = FLAGS .predict_batch_size )
339
+
340
+ # squad_lib_sp.convert_examples_to_features takes one extra argument, 'do_lower_case'.
341
+ if squad_lib == squad_lib_sp :
342
+ kwargs ['do_lower_case' ] = FLAGS .do_lower_case
343
+ dataset_size = squad_lib .convert_examples_to_features (** kwargs )
321
344
eval_writer .close ()
322
345
323
346
logging .info ('***** Running predictions *****' )
@@ -358,12 +381,10 @@ def export_squad(model_export_path, input_meta_data):
358
381
"""
359
382
if not model_export_path :
360
383
raise ValueError ('Export path is not specified: %s' % model_export_path )
361
- bert_config = modeling . BertConfig .from_json_file (FLAGS . bert_config_file )
362
-
384
+ bert_config = MODEL_CLASSES [ FLAGS . model_type ][ 0 ] .from_json_file (
385
+ FLAGS . bert_config_file )
363
386
squad_model , _ = bert_models .squad_model (
364
- bert_config ,
365
- input_meta_data ['max_seq_length' ],
366
- float_type = tf .float32 )
387
+ bert_config , input_meta_data ['max_seq_length' ], float_type = tf .float32 )
367
388
model_saving_utils .export_bert_model (
368
389
model_export_path , model = squad_model , checkpoint_dir = FLAGS .model_dir )
369
390
0 commit comments