
Commit a84f1b9

Use distributed dataset for NCF evaluation. (#9009)
PiperOrigin-RevId: 323948101
Co-authored-by: A. Unique TensorFlower <[email protected]>
1 parent 3729e19 commit a84f1b9

File tree

2 files changed: +20, -4 lines


official/recommendation/ncf_input_pipeline.py

Lines changed: 17 additions & 3 deletions
@@ -34,7 +34,8 @@
 def create_dataset_from_tf_record_files(input_file_pattern,
                                         pre_batch_size,
                                         batch_size,
-                                        is_training=True):
+                                        is_training=True,
+                                        rebatch=False):
   """Creates dataset from (tf)records files for training/evaluation."""

   files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
@@ -62,6 +63,13 @@ def make_dataset(files_dataset, shard_index):
       map_fn,
       cycle_length=NUM_SHARDS,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+  if rebatch:
+    # A workaround for TPU Pod evaluation dataset.
+    # TODO (b/162341937) remove once it's fixed.
+    dataset = dataset.unbatch()
+    dataset = dataset.batch(pre_batch_size)
+
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset

@@ -162,12 +170,18 @@ def create_ncf_input_data(params,
       params["train_dataset_path"],
       input_meta_data["train_prebatch_size"],
       params["batch_size"],
-      is_training=True)
+      is_training=True,
+      rebatch=False)
+
+  # Re-batch evaluation dataset for TPU Pods.
+  # TODO (b/162341937) remove once it's fixed.
+  eval_rebatch = (params["use_tpu"] and strategy.num_replicas_in_sync > 8)
   eval_dataset = create_dataset_from_tf_record_files(
       params["eval_dataset_path"],
       input_meta_data["eval_prebatch_size"],
       params["eval_batch_size"],
-      is_training=False)
+      is_training=False,
+      rebatch=eval_rebatch)

   num_train_steps = int(input_meta_data["num_train_steps"])
   num_eval_steps = int(input_meta_data["num_eval_steps"])
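For readers unfamiliar with the rebatch branch above, here is a minimal, self-contained sketch of what dataset.unbatch() followed by dataset.batch(pre_batch_size) does to a pre-batched tf.data pipeline. The toy dataset and sizes below are illustrative only and are not taken from the NCF code.

    import tensorflow as tf

    # Stand-in for the NCF evaluation data: each element is already a
    # "pre-batch" of 4 examples, mirroring how the TFRecords store
    # pre_batch_size examples per serialized record.
    prebatched = tf.data.Dataset.range(16).batch(4)

    # The workaround from the diff: flatten the pre-batches back into
    # individual examples, then re-batch to the same pre-batch size
    # before the dataset is distributed for evaluation.
    rebatched = prebatched.unbatch().batch(4)

    for original, rebuilt in zip(prebatched, rebatched):
      print(original.numpy(), rebuilt.numpy())  # identical batches

When every pre-batch holds exactly pre_batch_size examples, the round trip reproduces the same batches; the diff only enables it when params["use_tpu"] is set and strategy.num_replicas_in_sync > 8, i.e. on TPU Pods.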

official/recommendation/ncf_keras_main.py

Lines changed: 3 additions & 1 deletion
@@ -235,6 +235,7 @@ def run_ncf(_):

   params = ncf_common.parse_flags(FLAGS)
   params["distribute_strategy"] = strategy
+  params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")

   if params["use_tpu"] and not params["keras_use_ctl"]:
     logging.error("Custom training loop must be used when using TPUStrategy.")
@@ -491,7 +492,8 @@ def step_fn(features):
     logging.info("Done training epoch %s, epoch loss=%.3f", epoch + 1,
                  train_loss)

-    eval_input_iterator = iter(eval_input_dataset)
+    eval_input_iterator = iter(
+        strategy.experimental_distribute_dataset(eval_input_dataset))

     hr_sum = 0.0
     hr_count = 0.0
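The second hunk is the actual switch to a distributed evaluation dataset: the iterator is now built from strategy.experimental_distribute_dataset(eval_input_dataset) rather than from the raw host dataset. Below is a minimal sketch of that pattern outside the NCF code, using MirroredStrategy plus a toy dataset and metric; none of these names come from the repo.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # stand-in for the TPUStrategy NCF uses

    # Toy evaluation dataset; in ncf_keras_main.py this role is played by
    # eval_input_dataset built from the pre-batched TFRecords.
    eval_dataset = tf.data.Dataset.range(32).batch(8)

    @tf.function
    def eval_step(batch):
      # Placeholder per-replica "metric": just sum the batch.
      return tf.reduce_sum(batch)

    # The pattern introduced by the diff: distribute the dataset across
    # replicas, then iterate and run the eval step under the strategy.
    eval_iterator = iter(strategy.experimental_distribute_dataset(eval_dataset))

    for _ in range(4):  # the toy dataset has exactly 4 batches
      per_replica = strategy.run(eval_step, args=(next(eval_iterator),))
      print(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None).numpy())

Distributing the dataset before building the iterator lets each replica receive its shard of every evaluation batch, so per-replica results can be reduced into a single metric on a multi-replica setup such as a TPU Pod.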
