Skip to content

Commit 98b0763

Browse files
Disabling Tensorboard profiling for NCF. (#9959)
PiperOrigin-RevId: 371256980 Co-authored-by: A. Unique TensorFlower <[email protected]>
1 parent 8d51b58 commit 98b0763

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

official/recommendation/ncf_keras_main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,9 @@ def run_ncf(_):
258258
"val_HR_METRIC", desired_value=FLAGS.hr_threshold)
259259
callbacks.append(early_stopping_callback)
260260

261-
(train_input_dataset, eval_input_dataset,
262-
num_train_steps, num_eval_steps) = \
263-
(ncf_input_pipeline.create_ncf_input_data(
264-
params, producer, input_meta_data, strategy))
261+
(train_input_dataset, eval_input_dataset, num_train_steps,
262+
num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
263+
params, producer, input_meta_data, strategy)
265264
steps_per_epoch = None if generate_input_online else num_train_steps
266265

267266
with distribute_utils.get_strategy_scope(strategy):
@@ -307,7 +306,8 @@ def run_ncf(_):
307306
if not FLAGS.ml_perf:
308307
# Create Tensorboard summary and checkpoint callbacks.
309308
summary_dir = os.path.join(FLAGS.model_dir, "summaries")
310-
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
309+
summary_callback = tf.keras.callbacks.TensorBoard(
310+
summary_dir, profile_batch=0)
311311
checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
312312
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
313313
checkpoint_path, save_weights_only=True)

official/recommendation/ncf_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ def setUpClass(cls): # pylint: disable=invalid-name
3838
ncf_common.define_ncf_flags()
3939

4040
def setUp(self):
41+
super().setUp()
4142
self.top_k_old = rconst.TOP_K
4243
self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
4344
rconst.NUM_EVAL_NEGATIVES = 2
4445

4546
def tearDown(self):
47+
super().tearDown()
4648
rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
4749
rconst.TOP_K = self.top_k_old
4850

0 commit comments

Comments
 (0)