21
21
import functools
22
22
import json
23
23
import math
24
+ import os
24
25
25
26
from absl import app
26
27
from absl import flags
@@ -82,19 +83,19 @@ def classification_loss_fn(labels, logits):
82
83
return classification_loss_fn
83
84
84
85
85
- def run_customized_training (strategy ,
86
- bert_config ,
87
- input_meta_data ,
88
- model_dir ,
89
- epochs ,
90
- steps_per_epoch ,
91
- steps_per_loop ,
92
- eval_steps ,
93
- warmup_steps ,
94
- initial_lr ,
95
- init_checkpoint ,
96
- custom_callbacks = None ,
97
- run_eagerly = False ):
86
+ def run_bert_classifier (strategy ,
87
+ bert_config ,
88
+ input_meta_data ,
89
+ model_dir ,
90
+ epochs ,
91
+ steps_per_epoch ,
92
+ steps_per_loop ,
93
+ eval_steps ,
94
+ warmup_steps ,
95
+ initial_lr ,
96
+ init_checkpoint ,
97
+ custom_callbacks = None ,
98
+ run_eagerly = False ):
98
99
"""Run BERT classifier training using low-level API."""
99
100
max_seq_length = input_meta_data ['max_seq_length' ]
100
101
num_classes = input_meta_data ['num_labels' ]
@@ -144,6 +145,27 @@ def metric_fn():
144
145
return tf .keras .metrics .SparseCategoricalAccuracy (
145
146
'test_accuracy' , dtype = tf .float32 )
146
147
148
+ if FLAGS .use_keras_compile_fit :
149
+ # Start training using Keras compile/fit API.
150
+ logging .info ('Training using TF 2.0 Keras compile/fit API with '
151
+ 'distrubuted strategy.' )
152
+ return run_keras_compile_fit (
153
+ model_dir ,
154
+ strategy ,
155
+ _get_classifier_model ,
156
+ train_input_fn ,
157
+ eval_input_fn ,
158
+ loss_fn ,
159
+ metric_fn ,
160
+ init_checkpoint ,
161
+ epochs ,
162
+ steps_per_epoch ,
163
+ eval_steps ,
164
+ custom_callbacks = None )
165
+
166
+ # Use user-defined loop to start training.
167
+ logging .info ('Training using customized training loop TF 2.0 with '
168
+ 'distrubuted strategy.' )
147
169
return model_training_utils .run_customized_training_loop (
148
170
strategy = strategy ,
149
171
model_fn = _get_classifier_model ,
@@ -161,6 +183,52 @@ def metric_fn():
161
183
run_eagerly = run_eagerly )
162
184
163
185
186
def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          eval_steps,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API.

  Args:
    model_dir: Directory handed to the TensorBoard callback; also the parent
      of the per-epoch `model_checkpoint.{epoch:02d}` checkpoint paths.
    strategy: Object whose `scope()` context manager wraps dataset creation,
      model construction, compilation and fitting.
    model_fn: Zero-argument callable returning a `(model, sub_model)` pair.
      `model` must expose an `optimizer` attribute; `sub_model` is the object
      restored from `init_checkpoint`.
    train_input_fn: Zero-argument callable returning the training dataset.
    eval_input_fn: Zero-argument callable returning the evaluation dataset.
    loss_fn: Loss passed to `compile`.
    metric_fn: Zero-argument callable returning a single metric object.
    init_checkpoint: Checkpoint path to restore `sub_model` from. Restoration
      is skipped when this is falsy.
    epochs: Forwarded to `fit` as `epochs`.
    steps_per_epoch: Forwarded to `fit` as `steps_per_epoch`.
    eval_steps: Forwarded to `fit` as `validation_steps`.
    custom_callbacks: Optional iterable of extra callbacks. They run ahead of
      the TensorBoard and checkpoint callbacks added here. The argument itself
      is never modified.

  Returns:
    The first element returned by `model_fn`, after `fit` has completed.
  """

  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn()
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

    bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])

    summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
    checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)

    # Copy into a new list instead of extending the argument in place, so a
    # list reused by the caller does not accumulate callbacks across calls and
    # non-list iterables (e.g. tuples) are accepted.
    callbacks = list(custom_callbacks) if custom_callbacks is not None else []
    callbacks += [summary_callback, checkpoint_callback]

    bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=callbacks)

    return bert_model
230
+
231
+
164
232
def export_classifier (model_export_path , input_meta_data ):
165
233
"""Exports a trained model as a `SavedModel` for inference.
166
234
@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data):
203
271
204
272
if not strategy :
205
273
raise ValueError ('Distribution strategy has not been specified.' )
206
- # Runs customized training loop.
207
- logging .info ('Training using customized training loop TF 2.0 with distrubuted'
208
- 'strategy.' )
209
- trained_model = run_customized_training (
274
+
275
+ trained_model = run_bert_classifier (
210
276
strategy ,
211
277
bert_config ,
212
278
input_meta_data ,
0 commit comments