Commit 88b2a35

Adds a feature to process a dictionary of input_paths in DataConfig, allowing multiple datasets to be combined using a user-defined combine_fn.
PiperOrigin-RevId: 381363688
1 parent 50ebc68 commit 88b2a35
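
A minimal sketch of how the new option might be used. The file patterns, batch size, and mixing weights below are made up, and it assumes the params library converts the plain dict into a nested Config, as it does for overrides:

import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader

# Hypothetical shard patterns; any TFRecord files would do.
params = cfg.DataConfig(
    input_path={'wiki': '/data/wiki*.tfrecord',
                'books': '/data/books*.tfrecord'},
    global_batch_size=64,
    is_training=True)

def combine_fn(datasets):
  # `datasets` is a dict of tf.data.Dataset objects (already decoded if a
  # decoder_fn is set), keyed by the names used in `input_path`; here they
  # are mixed with fixed 70/30 weights.
  return tf.data.experimental.sample_from_datasets(
      [datasets['wiki'], datasets['books']], weights=[0.7, 0.3])

dataset = input_reader.InputReader(params, combine_fn=combine_fn).read()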

2 files changed: 79 additions & 51 deletions

official/core/config_definitions.py

Lines changed: 14 additions & 14 deletions
@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
   """The base configuration for building datasets.
 
   Attributes:
-    input_path: The path to the input. It can be either (1) a str indicating
-      a file path/pattern, or (2) a str indicating multiple file paths/patterns
-      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
-      (3) a list of str, each of which is a file path/pattern or multiple file
-      paths/patterns separated by comma.
-      It should not be specified when the following `tfds_name` is specified.
+    input_path: The path to the input. It can be either (1) a str indicating a
+      file path/pattern, or (2) a str indicating multiple file paths/patterns
+      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+      str, each of which is a file path/pattern or multiple file paths/patterns
+      separated by comma, or (4) a dictionary of the previous three approaches
+      for more advanced data mixing using named access. It should not be
+      specified when the following `tfds_name` is specified.
     tfds_name: The name of the tensorflow dataset (TFDS). It should not be
       specified when the above `input_path` is specified.
     tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
     shuffle_buffer_size: The buffer size used for shuffling training data.
     cache: Whether to cache dataset examples. If `True`, we will cache the
       dataset after applying the decode_fn and parse_fn. It can be used to avoid
-      re-reading from disk, re-decoding and re-parsing the example on the
-      second epoch, but it requires significant memory overhead.
+      re-reading from disk, re-decoding and re-parsing the example on the second
+      epoch, but it requires significant memory overhead.
     cycle_length: The number of files that will be processed concurrently when
       interleaving files.
     block_length: The number of consecutive elements to produce from each input
@@ -59,11 +60,10 @@ class DataConfig(base_config.Config):
     tf_data_service_address: The URI of a tf.data service to offload
       preprocessing onto during training. The URI should be in the format
       "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
-      overridden by `FLAGS.tf_data_service` flag in the binary.
-    tf_data_service_job_name: The name of the tf.data service job. This
-      argument makes it possible for multiple datasets to share the same job.
-      The default behavior is that the dataset creates anonymous, exclusively
-      owned jobs.
+      overridden by `FLAGS.tf_data_service` flag in the binary.
+    tf_data_service_job_name: The name of the tf.data service job. This argument
+      makes it possible for multiple datasets to share the same job. The default
+      behavior is that the dataset creates anonymous, exclusively owned jobs.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
       returned tf.data.Dataset will have a 2-tuple structure (input, label)
@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
       performance.
     seed: An optional seed to use for deterministic shuffling/preprocessing.
   """
-  input_path: Union[Sequence[str], str] = ""
+  input_path: Union[Sequence[str], str, base_config.Config] = ""
   tfds_name: str = ""
   tfds_split: str = ""
   global_batch_size: int = 0
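
To make the four accepted forms of `input_path` concrete, a hypothetical sketch (the paths are illustrative only; form (4) is the one this commit adds):

from official.core import config_definitions as cfg

# (1) a single file path/pattern
d1 = cfg.DataConfig(input_path='/data/train*.tfrecord')
# (2) multiple patterns in one comma-separated str
d2 = cfg.DataConfig(input_path='/data/a*.tfrecord,/data/b*.tfrecord')
# (3) a list of patterns
d3 = cfg.DataConfig(input_path=['/data/a*.tfrecord', '/data/b*.tfrecord'])
# (4) a dict of patterns, giving each dataset a name for mixing via combine_fn
d4 = cfg.DataConfig(input_path={'a': '/data/a*.tfrecord',
                                'b': '/data/b*.tfrecord'})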

official/core/input_reader.py

Lines changed: 65 additions & 37 deletions
@@ -14,7 +14,7 @@
 
 """A common dataset reader."""
 import random
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union, Dict, Sequence
 
 from absl import logging
 import tensorflow as tf
@@ -45,6 +45,7 @@ def __init__(self,
                params: cfg.DataConfig,
                dataset_fn=tf.data.TFRecordDataset,
                decoder_fn: Optional[Callable[..., Any]] = None,
+               combine_fn: Optional[Callable[..., Any]] = None,
                sample_fn: Optional[Callable[..., Any]] = None,
                parser_fn: Optional[Callable[..., Any]] = None,
                transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ def __init__(self,
         example, it can be `tf.data.TFRecordDataset`.
       decoder_fn: An optional `callable` that takes the serialized data string
         and decodes them into the raw tensor dictionary.
+      combine_fn: An optional `callable` that takes a dictionarty of
+        `tf.data.Dataset` objects as input and outputs a combined dataset. It
+        will be executed after the decoder_fn and before the sample_fn.
       sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
         input and outputs the transformed dataset. It performs sampling on the
         decoded raw tensors dict before the parser_fn.
@@ -78,10 +82,23 @@ def __init__(self,
       raise ValueError('At most one of `input_path` and `tfds_name` can be '
                        'specified, but got %s and %s.' %
                        (params.input_path, params.tfds_name))
+
+    if isinstance(params.input_path,
+                  cfg.base_config.Config) and combine_fn is None:
+      raise ValueError(
+          'A `combine_fn` is required if the `input_path` is a dictionary.')
+
     self._tfds_builder = None
-    self._matched_files = []
+    self._matched_files = None
     if params.input_path:
-      self._matched_files = self._match_files(params.input_path)
+      # we want to combine / mix datasets
+      if isinstance(params.input_path, cfg.base_config.Config):
+        self._matched_files = {}
+        for k, v in params.input_path.as_dict().items():
+          self._matched_files[k] = self._match_files(v)
+      # single dataset
+      else:
+        self._matched_files = self._match_files(params.input_path)
     else:
       # Read dataset from TFDS.
       if not params.tfds_split:
@@ -106,6 +123,7 @@ def __init__(self,
 
     self._dataset_fn = dataset_fn
     self._decoder_fn = decoder_fn
+    self._combine_fn = combine_fn
     self._sample_fn = sample_fn
     self._parser_fn = parser_fn
     self._transform_and_batch_fn = transform_and_batch_fn
@@ -131,7 +149,7 @@ def __init__(self,
     self._enable_round_robin_tf_data_service = params.get(
        'enable_round_robin_tf_data_service', False)
 
-  def _match_files(self, input_path: str) -> List[str]:
+  def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
     """Matches files from an input_path."""
     matched_files = []
     # Read dataset from files.
@@ -195,8 +213,8 @@ def _shard_files_then_read(
 
     # Do not enable sharding if tf.data service is enabled, as sharding will be
     # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+    if self._sharding and input_context and (input_context.num_input_pipelines >
+                                             1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
 
@@ -231,8 +249,8 @@ def _read_files_then_shard(
     dataset = dataset.with_options(options)
     # Do not enable sharding if tf.data service is enabled, as sharding will be
     # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+    if self._sharding and input_context and (input_context.num_input_pipelines >
+                                             1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
 
@@ -281,42 +299,53 @@ def tfds_info(self) -> tfds.core.DatasetInfo:
 
   def _read_decode_and_parse_dataset(
       self,
-      matched_files: List[str],
+      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
       batch_size: int,
       input_context: Optional[tf.distribute.InputContext] = None,
       tfds_builder: bool = False) -> tf.data.Dataset:
     """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+
+    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
+      if len(files) > 1:
+        if input_context and (len(files) < input_context.num_input_pipelines):
+          logging.warn(
+              'The number of files %d is less than the number of input pipelines '
+              '%d. We will send all input files to every worker. '
+              'Please consider sharding your data into more files.', len(files),
+              input_context.num_input_pipelines)
+          return self._read_files_then_shard(files, dataset_fn, input_context)
+        else:
+          return self._shard_files_then_read(files, dataset_fn, input_context)
+      elif len(files) == 1:
+        return self._read_files_then_shard(files, dataset_fn, input_context)
+      else:
+        raise ValueError('It is unexpected that `tfds_builder` is None and '
+                         'there is also no `files`.')
+
+    def _shuffle_and_decode(ds):
+      # If cache is enabled, we will call `shuffle()` later after `cache()`.
+      if self._is_training and not self._cache:
+        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
+      # Decode
+      ds = _maybe_map_fn(ds, self._decoder_fn)
+      return ds
+
     if tfds_builder:
       dataset = self._read_tfds(input_context)
-    elif len(matched_files) > 1:
-      if input_context and (len(matched_files) <
-                            input_context.num_input_pipelines):
-        logging.warn(
-            'The number of files %d is less than the number of input pipelines '
-            '%d. We will send all input files to every worker. '
-            'Please consider sharding your data into more files.',
-            len(matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(matched_files,
-                                              dataset_fn,
-                                              input_context)
-      else:
-        dataset = self._shard_files_then_read(matched_files,
-                                              dataset_fn,
-                                              input_context)
-    elif len(matched_files) == 1:
-      dataset = self._read_files_then_shard(matched_files,
-                                            dataset_fn,
-                                            input_context)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, (list, tuple)):
+      dataset = _files_to_dataset(matched_files)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, dict):
+      datasets = {}
+      for k, fs in matched_files.items():
+        datasets[k] = _files_to_dataset(fs)
+        datasets[k] = _shuffle_and_decode(datasets[k])
+      dataset = self._combine_fn(datasets)
     else:
-      raise ValueError('It is unexpected that `tfds_builder` is None and '
-                       'there is also no `matched_files`.')
-
-    # If cache is enabled, we will call `shuffle()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
+      raise ValueError('`matched_files` should be a list or dict.')
 
-    dataset = _maybe_map_fn(dataset, self._decoder_fn)
     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
     dataset = _maybe_map_fn(dataset, self._parser_fn)
@@ -333,8 +362,7 @@ def _read_decode_and_parse_dataset(
     per_replica_batch_size = input_context.get_per_replica_batch_size(
         batch_size) if input_context else batch_size
     dataset = dataset.batch(
-        per_replica_batch_size, drop_remainder=self._drop_remainder
-    )
+        per_replica_batch_size, drop_remainder=self._drop_remainder)
 
     return dataset
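
To illustrate the contract the new hook implies (a dict of per-name datasets in, a single dataset out, run after decoder_fn and before sample_fn), here is a hypothetical combiner that pairs one example from each named dataset; it is not part of this commit:

from typing import Dict

import tensorflow as tf


def zip_combine_fn(datasets: Dict[str, tf.data.Dataset]) -> tf.data.Dataset:
  """Pairs one example from each named dataset into a single nested dict."""
  names = sorted(datasets.keys())
  zipped = tf.data.Dataset.zip(tuple(datasets[name] for name in names))
  # Re-attach the names so a downstream parser_fn can tell the sources apart.
  return zipped.map(lambda *examples: dict(zip(names, examples)))

A weighted-sampling combiner (e.g. tf.data.experimental.sample_from_datasets, as sketched earlier) is the other common choice.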
