Skip to content

Commit 7207422

Browse files
Internal change
PiperOrigin-RevId: 273653001
1 parent dc93d9e commit 7207422

File tree

2 files changed

+87
-17
lines changed

2 files changed

+87
-17
lines changed

official/nlp/bert/common_flags.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def define_common_bert_flags():
5656
'scale_loss', False,
5757
'Whether to divide the loss by number of replica inside the per-replica '
5858
'loss function.')
59+
flags.DEFINE_boolean(
60+
'use_keras_compile_fit', False,
61+
'If True, uses Keras compile/fit() API for training logic. Otherwise '
62+
'use custom training loop.')
5963

6064
# Adds flags for mixed precision training.
6165
flags_core.define_performance(

official/nlp/bert/run_classifier.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222
import json
2323
import math
24+
import os
2425

2526
from absl import app
2627
from absl import flags
@@ -82,19 +83,19 @@ def classification_loss_fn(labels, logits):
8283
return classification_loss_fn
8384

8485

85-
def run_customized_training(strategy,
86-
bert_config,
87-
input_meta_data,
88-
model_dir,
89-
epochs,
90-
steps_per_epoch,
91-
steps_per_loop,
92-
eval_steps,
93-
warmup_steps,
94-
initial_lr,
95-
init_checkpoint,
96-
custom_callbacks=None,
97-
run_eagerly=False):
86+
def run_bert_classifier(strategy,
87+
bert_config,
88+
input_meta_data,
89+
model_dir,
90+
epochs,
91+
steps_per_epoch,
92+
steps_per_loop,
93+
eval_steps,
94+
warmup_steps,
95+
initial_lr,
96+
init_checkpoint,
97+
custom_callbacks=None,
98+
run_eagerly=False):
9899
"""Run BERT classifier training using low-level API."""
99100
max_seq_length = input_meta_data['max_seq_length']
100101
num_classes = input_meta_data['num_labels']
@@ -144,6 +145,27 @@ def metric_fn():
144145
return tf.keras.metrics.SparseCategoricalAccuracy(
145146
'test_accuracy', dtype=tf.float32)
146147

148+
if FLAGS.use_keras_compile_fit:
149+
# Start training using Keras compile/fit API.
150+
logging.info('Training using TF 2.0 Keras compile/fit API with '
151+
'distrubuted strategy.')
152+
return run_keras_compile_fit(
153+
model_dir,
154+
strategy,
155+
_get_classifier_model,
156+
train_input_fn,
157+
eval_input_fn,
158+
loss_fn,
159+
metric_fn,
160+
init_checkpoint,
161+
epochs,
162+
steps_per_epoch,
163+
eval_steps,
164+
custom_callbacks=None)
165+
166+
# Use user-defined loop to start training.
167+
logging.info('Training using customized training loop TF 2.0 with '
168+
'distrubuted strategy.')
147169
return model_training_utils.run_customized_training_loop(
148170
strategy=strategy,
149171
model_fn=_get_classifier_model,
@@ -161,6 +183,52 @@ def metric_fn():
161183
run_eagerly=run_eagerly)
162184

163185

186+
def run_keras_compile_fit(model_dir,
187+
strategy,
188+
model_fn,
189+
train_input_fn,
190+
eval_input_fn,
191+
loss_fn,
192+
metric_fn,
193+
init_checkpoint,
194+
epochs,
195+
steps_per_epoch,
196+
eval_steps,
197+
custom_callbacks=None):
198+
"""Runs BERT classifier model using Keras compile/fit API."""
199+
200+
with strategy.scope():
201+
training_dataset = train_input_fn()
202+
evaluation_dataset = eval_input_fn()
203+
bert_model, sub_model = model_fn()
204+
optimizer = bert_model.optimizer
205+
206+
if init_checkpoint:
207+
checkpoint = tf.train.Checkpoint(model=sub_model)
208+
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
209+
210+
bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
211+
212+
summary_callback = tf.keras.callbacks.TensorBoard(model_dir)
213+
checkpoint_dir = os.path.join(model_dir, 'model_checkpoint.{epoch:02d}')
214+
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir)
215+
216+
if custom_callbacks is not None:
217+
custom_callbacks += [summary_callback, checkpoint_callback]
218+
else:
219+
custom_callbacks = [summary_callback, checkpoint_callback]
220+
221+
bert_model.fit(
222+
x=training_dataset,
223+
validation_data=evaluation_dataset,
224+
steps_per_epoch=steps_per_epoch,
225+
epochs=epochs,
226+
validation_steps=eval_steps,
227+
callbacks=custom_callbacks)
228+
229+
return bert_model
230+
231+
164232
def export_classifier(model_export_path, input_meta_data):
165233
"""Exports a trained model as a `SavedModel` for inference.
166234
@@ -203,10 +271,8 @@ def run_bert(strategy, input_meta_data):
203271

204272
if not strategy:
205273
raise ValueError('Distribution strategy has not been specified.')
206-
# Runs customized training loop.
207-
logging.info('Training using customized training loop TF 2.0 with distrubuted'
208-
'strategy.')
209-
trained_model = run_customized_training(
274+
275+
trained_model = run_bert_classifier(
210276
strategy,
211277
bert_config,
212278
input_meta_data,

0 commit comments

Comments
 (0)