Skip to content

Commit 7f66d53

Browse files
saberkun authored and tensorflower-gardener committed
Add seed to DataConfig. Add deterministic tests for input reader.
PiperOrigin-RevId: 364734183
1 parent 5ae21ff commit 7f66d53

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

official/core/config_definitions.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ class DataConfig(base_config.Config):
7373
decoding when loading dataset from TFDS. Use comma to separate multiple
7474
features. The main use case is to skip the image/video decoding for better
7575
performance.
76+
seed: An optional seed to use for deterministic shuffling/preprocessing.
7677
"""
7778
input_path: Union[Sequence[str], str] = ""
7879
tfds_name: str = ""
@@ -92,6 +93,7 @@ class DataConfig(base_config.Config):
9293
tfds_data_dir: str = ""
9394
tfds_as_supervised: bool = False
9495
tfds_skip_decoding_feature: str = ""
96+
seed: Optional[int] = None
9597

9698

9799
@dataclasses.dataclass

official/core/input_reader.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -112,8 +112,12 @@ def __init__(self,
112112
self._postprocess_fn = postprocess_fn
113113
# When tf.data service is enabled, each data service worker should get
114114
# different random seeds. Thus, we set `seed` to None.
115-
self._seed = (None
116-
if params.enable_tf_data_service else _get_random_integer())
115+
if params.seed is not None:
116+
self._seed = params.seed
117+
elif params.enable_tf_data_service:
118+
self._seed = _get_random_integer()
119+
else:
120+
self._seed = None
117121

118122
self._enable_tf_data_service = (
119123
params.enable_tf_data_service and params.tf_data_service_address)
@@ -243,7 +247,8 @@ def _read_tfds(
243247
read_config = tfds.ReadConfig(
244248
interleave_cycle_length=self._cycle_length,
245249
interleave_block_length=self._block_length,
246-
input_context=input_context)
250+
input_context=input_context,
251+
shuffle_seed=self._seed)
247252
decoders = {}
248253
if self._tfds_skip_decoding_feature:
249254
for skip_feature in self._tfds_skip_decoding_feature.split(','):
@@ -304,7 +309,7 @@ def _read_decode_and_parse_dataset(
304309

305310
# If cache is enabled, we will call `shuffle()` later after `cache()`.
306311
if self._is_training and not self._cache:
307-
dataset = dataset.shuffle(self._shuffle_buffer_size)
312+
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
308313

309314
dataset = _maybe_map_fn(dataset, self._decoder_fn)
310315
if self._sample_fn is not None:
@@ -315,7 +320,7 @@ def _read_decode_and_parse_dataset(
315320
dataset = dataset.cache()
316321
if self._is_training:
317322
dataset = dataset.repeat()
318-
dataset = dataset.shuffle(self._shuffle_buffer_size)
323+
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
319324

320325
if self._transform_and_batch_fn is not None:
321326
dataset = self._transform_and_batch_fn(dataset, input_context)

0 commit comments

Comments
 (0)