"""A common dataset reader."""
import random
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union, Dict, Sequence

from absl import logging
import tensorflow as tf
@@ -45,6 +45,7 @@ def __init__(self,
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
+              combine_fn: Optional[Callable[..., Any]] = None,
               sample_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ def __init__(self,
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
+      combine_fn: An optional `callable` that takes a dictionary of
+        `tf.data.Dataset` objects as input and outputs a combined dataset. It
+        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
        input and outputs the transformed dataset. It performs sampling on the
        decoded raw tensors dict before the parser_fn.
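Note: `combine_fn` is simply a callable from the dict of per-source decoded datasets to a single dataset. A minimal sketch of one possible implementation (the keys 'a' and 'b' and the 80/20 weights are hypothetical, not part of this change), using `tf.data.experimental.sample_from_datasets` as the mixing strategy:

import tensorflow as tf

def mix_combine_fn(datasets):
  # `datasets` maps each `input_path` key to its decoded tf.data.Dataset,
  # e.g. {'a': ds_a, 'b': ds_b}; any callable that returns one dataset works.
  return tf.data.experimental.sample_from_datasets(
      [datasets['a'], datasets['b']], weights=[0.8, 0.2])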
@@ -78,10 +82,23 @@ def __init__(self,
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))
+
+    if isinstance(params.input_path,
+                  cfg.base_config.Config) and combine_fn is None:
+      raise ValueError(
+          'A `combine_fn` is required if the `input_path` is a dictionary.')
+
    self._tfds_builder = None
-    self._matched_files = []
+    self._matched_files = None
    if params.input_path:
-      self._matched_files = self._match_files(params.input_path)
+      # we want to combine / mix datasets
+      if isinstance(params.input_path, cfg.base_config.Config):
+        self._matched_files = {}
+        for k, v in params.input_path.as_dict().items():
+          self._matched_files[k] = self._match_files(v)
+      # single dataset
+      else:
+        self._matched_files = self._match_files(params.input_path)
    else:
      # Read dataset from TFDS.
      if not params.tfds_split:
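For reference, the dict branch above expects `params.input_path` to itself be a nested config mapping dataset names to file patterns. A hedged usage sketch (keys, paths, and batch size are made up, and it assumes the surrounding DataConfig accepts a dict for `input_path`):

params = cfg.DataConfig(
    input_path={'a': '/data/a-*.tfrecord', 'b': '/data/b-*.tfrecord'},
    global_batch_size=64,
    is_training=True)
# With a dict-valued `input_path`, a `combine_fn` must also be passed to the
# InputReader constructor; otherwise the ValueError above is raised.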
@@ -106,6 +123,7 @@ def __init__(self,

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
+    self._combine_fn = combine_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
@@ -131,7 +149,7 @@ def __init__(self,
    self._enable_round_robin_tf_data_service = params.get(
        'enable_round_robin_tf_data_service', False)

-  def _match_files(self, input_path: str) -> List[str]:
+  def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
    """Matches files from an input_path."""
    matched_files = []
    # Read dataset from files.
@@ -195,8 +213,8 @@ def _shard_files_then_read(

    # Do not enable sharding if tf.data service is enabled, as sharding will be
    # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+    if self._sharding and input_context and (input_context.num_input_pipelines >
+                                             1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
@@ -231,8 +249,8 @@ def _read_files_then_shard(
    dataset = dataset.with_options(options)
    # Do not enable sharding if tf.data service is enabled, as sharding will be
    # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+    if self._sharding and input_context and (input_context.num_input_pipelines >
+                                             1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)

@@ -281,42 +299,53 @@ def tfds_info(self) -> tfds.core.DatasetInfo:

  def _read_decode_and_parse_dataset(
      self,
-      matched_files: List[str],
+      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: bool = False) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+
+    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
+      if len(files) > 1:
+        if input_context and (len(files) < input_context.num_input_pipelines):
+          logging.warn(
+              'The number of files %d is less than the number of input pipelines '
+              '%d. We will send all input files to every worker. '
+              'Please consider sharding your data into more files.', len(files),
+              input_context.num_input_pipelines)
+          return self._read_files_then_shard(files, dataset_fn, input_context)
+        else:
+          return self._shard_files_then_read(files, dataset_fn, input_context)
+      elif len(files) == 1:
+        return self._read_files_then_shard(files, dataset_fn, input_context)
+      else:
+        raise ValueError('It is unexpected that `tfds_builder` is None and '
+                         'there is also no `files`.')
+
+    def _shuffle_and_decode(ds):
+      # If cache is enabled, we will call `shuffle()` later after `cache()`.
+      if self._is_training and not self._cache:
+        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
+      # Decode
+      ds = _maybe_map_fn(ds, self._decoder_fn)
+      return ds
+
    if tfds_builder:
      dataset = self._read_tfds(input_context)
-    elif len(matched_files) > 1:
-      if input_context and (len(matched_files) <
-                            input_context.num_input_pipelines):
-        logging.warn(
-            'The number of files %d is less than the number of input pipelines '
-            '%d. We will send all input files to every worker. '
-            'Please consider sharding your data into more files.',
-            len(matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(matched_files,
-                                              dataset_fn,
-                                              input_context)
-      else:
-        dataset = self._shard_files_then_read(matched_files,
-                                              dataset_fn,
-                                              input_context)
-    elif len(matched_files) == 1:
-      dataset = self._read_files_then_shard(matched_files,
-                                            dataset_fn,
-                                            input_context)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, (list, tuple)):
+      dataset = _files_to_dataset(matched_files)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, dict):
+      datasets = {}
+      for k, fs in matched_files.items():
+        datasets[k] = _files_to_dataset(fs)
+        datasets[k] = _shuffle_and_decode(datasets[k])
+      dataset = self._combine_fn(datasets)
    else:
-      raise ValueError('It is unexpected that `tfds_builder` is None and '
-                       'there is also no `matched_files`.')
-
-    # If cache is enabled, we will call `shuffle()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
+      raise ValueError('`matched_files` should be a list or dict.')

-    dataset = _maybe_map_fn(dataset, self._decoder_fn)
    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
    dataset = _maybe_map_fn(dataset, self._parser_fn)
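For context on where the combination happens: each entry of the `matched_files` dict is read, sharded, shuffled, and decoded independently, `combine_fn` then merges the per-key datasets, and only afterwards do `sample_fn` and `parser_fn` run on the merged stream. A sketch of an alternative pairing-style combiner (the key names are hypothetical) that zips the sources instead of sampling from them:

import tensorflow as tf

def zip_combine_fn(datasets):
  # Each element of the combined dataset is a pair of decoded feature dicts,
  # one from each source; the parser_fn can then merge or relabel them.
  return tf.data.Dataset.zip((datasets['labeled'], datasets['unlabeled']))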
@@ -333,8 +362,7 @@ def _read_decode_and_parse_dataset(
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        batch_size) if input_context else batch_size
    dataset = dataset.batch(
-        per_replica_batch_size, drop_remainder=self._drop_remainder
-    )
+        per_replica_batch_size, drop_remainder=self._drop_remainder)

    return dataset