@@ -34,7 +34,8 @@
 def create_dataset_from_tf_record_files(input_file_pattern,
                                         pre_batch_size,
                                         batch_size,
-                                        is_training=True):
+                                        is_training=True,
+                                        rebatch=False):
   """Creates dataset from (tf)records files for training/evaluation."""

   files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
@@ -62,6 +63,13 @@ def make_dataset(files_dataset, shard_index):
       map_fn,
       cycle_length=NUM_SHARDS,
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+  if rebatch:
+    # A workaround for TPU Pod evaluation dataset.
+    # TODO (b/162341937) remove once it's fixed.
+    dataset = dataset.unbatch()
+    dataset = dataset.batch(pre_batch_size)
+
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset

@@ -162,12 +170,18 @@ def create_ncf_input_data(params,
       params["train_dataset_path"],
       input_meta_data["train_prebatch_size"],
       params["batch_size"],
-      is_training=True)
+      is_training=True,
+      rebatch=False)
+
+  # Re-batch evaluation dataset for TPU Pods.
+  # TODO (b/162341937) remove once it's fixed.
+  eval_rebatch = (params["use_tpu"] and strategy.num_replicas_in_sync > 8)
   eval_dataset = create_dataset_from_tf_record_files(
       params["eval_dataset_path"],
       input_meta_data["eval_prebatch_size"],
       params["eval_batch_size"],
-      is_training=False)
+      is_training=False,
+      rebatch=eval_rebatch)

   num_train_steps = int(input_meta_data["num_train_steps"])
   num_eval_steps = int(input_meta_data["num_eval_steps"])
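For reference, the rebatch workaround above relies on standard tf.data behavior: unbatch() splits each pre-batched element back into individual examples, and batch(pre_batch_size) regroups them into elements of a fixed, known size, so examples flow across shard boundaries and ragged per-shard remainder batches disappear. A minimal standalone sketch of that pattern (the toy data and names such as shard_a/shard_b are illustrative, not part of this change):

import tensorflow as tf

# Two toy "shards" whose record counts are not multiples of the pre-batch
# size, so each one ends in a ragged remainder batch of 2 after batching.
shard_a = tf.data.Dataset.range(0, 10).batch(4)   # batches: 4, 4, 2
shard_b = tf.data.Dataset.range(10, 20).batch(4)  # batches: 4, 4, 2
combined = shard_a.concatenate(shard_b)

# The pattern used by the workaround: flatten back to single examples, then
# re-batch so every element (except possibly the last) holds exactly
# pre_batch_size examples, regardless of shard boundaries.
pre_batch_size = 4
rebatched = combined.unbatch().batch(pre_batch_size)

for batch in rebatched:
  print(batch.numpy())
# [0 1 2 3] [4 5 6 7] [8 9 10 11] [12 13 14 15] [16 17 18 19]

The eval_rebatch gate (params["use_tpu"] and strategy.num_replicas_in_sync > 8) presumably targets TPU Pod slices: a single TPU board exposes 8 cores, so more than 8 replicas implies a multi-host setup, where the evaluation dataset benefits from uniform per-element batch sizes.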