Skip to content

Commit 7acce35

Browse files
rachellj218tensorflower-gardener
authored andcommitted
Fix bug in test dataset generation
PiperOrigin-RevId: 273066504
1 parent f8d9c9b commit 7acce35

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

official/resnet/ctl/ctl_imagenet_main.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,14 @@ def _test_data_fn(ctx=None):
128128
input_context=ctx)
129129
return test_ds
130130

131-
if strategy:
132-
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
133-
test_ds = strategy.experimental_distribute_datasets_from_function(_test_data_fn)
131+
if strategy:
132+
if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
133+
test_ds = strategy.experimental_distribute_datasets_from_function(
134+
_test_data_fn)
135+
else:
136+
test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
134137
else:
135-
test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
136-
else:
137-
test_ds = _test_data_fn()
138+
test_ds = _test_data_fn()
138139

139140
return train_ds, test_ds
140141

0 commit comments

Comments
 (0)