@@ -258,10 +258,9 @@ def run_ncf(_):
258
258
"val_HR_METRIC" , desired_value = FLAGS .hr_threshold )
259
259
callbacks .append (early_stopping_callback )
260
260
261
- (train_input_dataset , eval_input_dataset ,
262
- num_train_steps , num_eval_steps ) = \
263
- (ncf_input_pipeline .create_ncf_input_data (
264
- params , producer , input_meta_data , strategy ))
261
+ (train_input_dataset , eval_input_dataset , num_train_steps ,
262
+ num_eval_steps ) = ncf_input_pipeline .create_ncf_input_data (
263
+ params , producer , input_meta_data , strategy )
265
264
steps_per_epoch = None if generate_input_online else num_train_steps
266
265
267
266
with distribute_utils .get_strategy_scope (strategy ):
@@ -307,7 +306,8 @@ def run_ncf(_):
307
306
if not FLAGS .ml_perf :
308
307
# Create Tensorboard summary and checkpoint callbacks.
309
308
summary_dir = os .path .join (FLAGS .model_dir , "summaries" )
310
- summary_callback = tf .keras .callbacks .TensorBoard (summary_dir )
309
+ summary_callback = tf .keras .callbacks .TensorBoard (
310
+ summary_dir , profile_batch = 0 )
311
311
checkpoint_path = os .path .join (FLAGS .model_dir , "checkpoint" )
312
312
checkpoint_callback = tf .keras .callbacks .ModelCheckpoint (
313
313
checkpoint_path , save_weights_only = True )
0 commit comments