File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed
official/vision/image_classification Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -232,9 +232,12 @@ def run(flags_obj):
232
232
validation_freq = flags_obj .epochs_between_evals ,
233
233
verbose = 2 )
234
234
if flags_obj .enable_checkpoint_and_export :
235
- # Keras model.save assumes a float32 input designature.
236
- export_path = os .path .join (flags_obj .model_dir , 'saved_model' )
237
- model .save (export_path , include_optimizer = False )
235
+ if dtype == tf .bfloat16 :
236
+ logging .warning ("Keras model.save does not support bfloat16 dtype." )
237
+ else :
238
+ # Keras model.save assumes a float32 input designature.
239
+ export_path = os .path .join (flags_obj .model_dir , 'saved_model' )
240
+ model .save (export_path , include_optimizer = False )
238
241
239
242
eval_output = None
240
243
if not flags_obj .skip_eval :
You can’t perform that action at this time.
0 commit comments