@@ -112,8 +112,12 @@ def __init__(self,
     self._postprocess_fn = postprocess_fn
     # When tf.data service is enabled, each data service worker should get
     # different random seeds. Thus, we set `seed` to None.
-    self._seed = (None
-                  if params.enable_tf_data_service else _get_random_integer())
+    if params.seed is not None:
+      self._seed = params.seed
+    elif params.enable_tf_data_service:
+      self._seed = None
+    else:
+      self._seed = _get_random_integer()
 
     self._enable_tf_data_service = (
         params.enable_tf_data_service and params.tf_data_service_address)
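The new branches resolve the seed with an explicit precedence: a user-configured `params.seed` always wins; otherwise, when tf.data service is enabled, the seed stays `None` so each service worker shuffles differently (as the comment states); otherwise a single random seed is drawn. Below is a minimal sketch of that precedence outside the reader class; `resolve_seed` is a hypothetical helper and the body of `_get_random_integer` is an assumption, not the module's actual implementation.

```python
import random
from typing import Optional


def _get_random_integer() -> int:
  # Stand-in for the module's helper; the exact bound is an assumption.
  return random.randint(0, (1 << 31) - 1)


def resolve_seed(user_seed: Optional[int],
                 enable_tf_data_service: bool) -> Optional[int]:
  """Mirrors the precedence introduced in the hunk above."""
  if user_seed is not None:
    return user_seed  # An explicit seed always wins.
  if enable_tf_data_service:
    return None  # Each tf.data service worker should shuffle differently.
  return _get_random_integer()  # One process-wide random seed otherwise.


assert resolve_seed(42, True) == 42
assert resolve_seed(None, True) is None
```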
@@ -243,7 +247,8 @@ def _read_tfds(
     read_config = tfds.ReadConfig(
         interleave_cycle_length=self._cycle_length,
         interleave_block_length=self._block_length,
-        input_context=input_context)
+        input_context=input_context,
+        shuffle_seed=self._seed)
     decoders = {}
     if self._tfds_skip_decoding_feature:
       for skip_feature in self._tfds_skip_decoding_feature.split(','):
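Passing `shuffle_seed` into `tfds.ReadConfig` makes the TFDS file-level shuffle follow the reader's seed instead of being re-randomized on every run. A rough usage sketch of the same knob with `tfds.load`; the dataset name and numeric values are placeholders, not part of this change.

```python
import tensorflow_datasets as tfds

# Placeholder values; the reader builds an equivalent ReadConfig internally.
read_config = tfds.ReadConfig(
    interleave_cycle_length=16,
    interleave_block_length=1,
    shuffle_seed=42)

ds = tfds.load(
    'mnist',
    split='train',
    shuffle_files=True,  # File order is now governed by `shuffle_seed`.
    read_config=read_config)
```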
@@ -304,7 +309,7 @@ def _read_decode_and_parse_dataset(
 
     # If cache is enabled, we will call `shuffle()` later after `cache()`.
     if self._is_training and not self._cache:
-      dataset = dataset.shuffle(self._shuffle_buffer_size)
+      dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
 
     dataset = _maybe_map_fn(dataset, self._decoder_fn)
     if self._sample_fn is not None:
@@ -315,7 +320,7 @@ def _read_decode_and_parse_dataset(
       dataset = dataset.cache()
       if self._is_training:
         dataset = dataset.repeat()
-        dataset = dataset.shuffle(self._shuffle_buffer_size)
+        dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
 
     if self._transform_and_batch_fn is not None:
       dataset = self._transform_and_batch_fn(dataset, input_context)
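Both seeded `shuffle()` calls above make the element-level shuffle order reproducible whenever `self._seed` is set. A toy sketch of the effect, independent of the reader code and using only standard `tf.data` calls:

```python
import tensorflow as tf


def make_ds(seed):
  # Stand-in for the reader pipeline: a seeded shuffle over a small range.
  return tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=seed)


run_a = list(make_ds(seed=7).as_numpy_iterator())
run_b = list(make_ds(seed=7).as_numpy_iterator())
assert run_a == run_b  # Same seed, same shuffle order across pipelines.
```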