Skip to content

Commit dc93d9e

Browse files
Internal change
PiperOrigin-RevId: 273562498
1 parent d0f21f2 commit dc93d9e

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

official/vision/image_classification/resnet_imagenet_main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,12 @@ def run(flags_obj):
232232
validation_freq=flags_obj.epochs_between_evals,
233233
verbose=2)
234234
if flags_obj.enable_checkpoint_and_export:
235-
# Keras model.save assumes a float32 input designature.
236-
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
237-
model.save(export_path, include_optimizer=False)
235+
if dtype == tf.bfloat16:
236+
logging.warning("Keras model.save does not support bfloat16 dtype.")
237+
else:
238+
# Keras model.save assumes a float32 input designature.
239+
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
240+
model.save(export_path, include_optimizer=False)
238241

239242
eval_output = None
240243
if not flags_obj.skip_eval:

0 commit comments

Comments
 (0)