Skip to content

Commit fb68321

Browse files
author
The TensorFlow Datasets Authors
committed
Add a beam writer that doesn't shuffle
PiperOrigin-RevId: 691542774
1 parent 1940fa6 commit fb68321

14 files changed

+409
-136
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,10 @@ def download_and_prepare(
623623
data_path = self.data_path
624624
data_exists = data_path.exists()
625625

626+
# Saving nondeterministic_order in the DatasetInfo for documentation.
627+
if download_config.nondeterministic_order:
628+
self.info.set_nondeterministic_order(True)
629+
626630
if download_config.download_mode == UPDATE_DATASET_INFO:
627631
self._update_dataset_info()
628632
return
@@ -1427,11 +1431,13 @@ def _get_filename_template(
14271431
self, split_name: str
14281432
) -> naming.ShardedFileTemplate:
14291433
"""Returns a filename template for the given split."""
1434+
if self.info.file_format is None:
1435+
raise ValueError("File format is not set!")
14301436
return naming.ShardedFileTemplate(
14311437
split=split_name,
14321438
dataset_name=self.name,
14331439
data_dir=self.data_path,
1434-
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
1440+
filetype_suffix=self.info.file_format.file_suffix,
14351441
)
14361442

14371443

@@ -1729,6 +1735,7 @@ def _generate_splits(
17291735
generator=generator,
17301736
filename_template=filename_template,
17311737
disable_shuffling=self.info.disable_shuffling,
1738+
nondeterministic_order=download_config.nondeterministic_order,
17321739
)
17331740
split_info_futures.append(future)
17341741

tensorflow_datasets/core/dataset_builder_beam_test.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ class DummyBeamDataset(dataset_builder.GeneratorBasedBuilder):
3939
'valid_725': 725,
4040
}
4141

42+
FEATURE_DICT = features.FeaturesDict({
43+
'image': features.Image(shape=(16, 16, 1)),
44+
'label': features.ClassLabel(names=['dog', 'cat']),
45+
'id': tf.int32,
46+
})
47+
4248
def _info(self):
4349
return dataset_info.DatasetInfo(
4450
builder=self,
45-
features=features.FeaturesDict({
46-
'image': features.Image(shape=(16, 16, 1)),
47-
'label': features.ClassLabel(names=['dog', 'cat']),
48-
'id': tf.int32,
49-
}),
51+
features=self.FEATURE_DICT,
5052
supervised_keys=('x', 'x'),
5153
metadata=dataset_info.BeamMetadataDict(),
5254
)
@@ -71,6 +73,18 @@ def _generate_examples(self, num_examples):
7173
return examples
7274

7375

76+
class UnshuffledDummyBeamDataset(DummyBeamDataset):
77+
78+
def _info(self) -> dataset_info.DatasetInfo:
79+
return dataset_info.DatasetInfo(
80+
builder=self,
81+
features=self.FEATURE_DICT,
82+
supervised_keys=('x', 'x'),
83+
metadata=dataset_info.BeamMetadataDict(),
84+
disable_shuffling=True,
85+
)
86+
87+
7488
class CommonPipelineDummyBeamDataset(DummyBeamDataset):
7589
EXPECTED_METADATA = {
7690
'label_sum_1000': 500,
@@ -151,12 +165,21 @@ def _compute_mean(examples):
151165
)
152166

153167

168+
def get_id(ex):
169+
return ex['id']
170+
171+
154172
def make_default_config():
155173
return download.DownloadConfig()
156174

157175

158176
@pytest.mark.parametrize(
159-
'dataset_cls', [DummyBeamDataset, CommonPipelineDummyBeamDataset]
177+
'dataset_cls',
178+
[
179+
DummyBeamDataset,
180+
CommonPipelineDummyBeamDataset,
181+
UnshuffledDummyBeamDataset,
182+
],
160183
)
161184
@pytest.mark.parametrize(
162185
'make_dl_config',
@@ -178,29 +201,23 @@ def test_beam_datasets(
178201
assert data_path.exists() # Dataset has been generated
179202

180203
# Check number of shards/generated files
181-
_test_shards(
182-
data_path,
183-
pattern='%s-test.tfrecord-{:05}-of-{:05}' % dataset_name,
184-
# Liquid sharding is not guaranteed to always use the same number.
185-
num_shards=builder.info.splits['test'].num_shards,
186-
)
187-
_test_shards(
188-
data_path,
189-
pattern='%s-train.tfrecord-{:05}-of-{:05}' % dataset_name,
190-
num_shards=1,
191-
)
204+
for split in ['test', 'train']:
205+
_test_shards(
206+
data_path,
207+
pattern='%s-%s.tfrecord-{:05}-of-{:05}' % (dataset_name, split),
208+
num_shards=builder.info.splits[split].num_shards,
209+
)
192210

193211
ds = dataset_utils.as_numpy(builder.as_dataset())
194212

195-
def get_id(ex):
196-
return ex['id']
197-
213+
test_examples = list(ds['test'])
214+
train_examples = list(ds['train'])
198215
_assert_values_equal(
199-
sorted(list(ds['test']), key=get_id),
216+
sorted(test_examples, key=get_id),
200217
sorted([_gen_example(i)[1] for i in range(725)], key=get_id),
201218
)
202219
_assert_values_equal(
203-
sorted(list(ds['train']), key=get_id),
220+
sorted(train_examples, key=get_id),
204221
sorted([_gen_example(i)[1] for i in range(1000)], key=get_id),
205222
)
206223

tensorflow_datasets/core/dataset_info.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def __init__(
186186
features: feature_lib.FeatureConnector | None = None,
187187
supervised_keys: SupervisedKeysType | None = None,
188188
disable_shuffling: bool = False,
189+
nondeterministic_order: bool = False,
189190
homepage: str | None = None,
190191
citation: str | None = None,
191192
metadata: Metadata | None = None,
@@ -228,7 +229,11 @@ def __init__(
228229
229230
Note that selecting features in nested `tfds.features.FeaturesDict`
230231
objects is not supported.
231-
disable_shuffling: `bool`, specify whether to shuffle the examples.
232+
disable_shuffling: `bool`, specifies whether to disable shuffling of the examples.
233+
nondeterministic_order: `bool`, if True and the dataset uses beam, it will
234+
use `NoShuffleBeamWriter` which does not ensure deterministic
235+
shuffling when writing examples to disk. This might result in quicker
236+
dataset preparation.
232237
homepage: `str`, optional, the homepage for this dataset.
233238
citation: `str`, optional, the citation to use for this dataset.
234239
metadata: `tfds.core.Metadata`, additional object which will be
@@ -268,6 +273,7 @@ def __init__(
268273
version=str(self._identity.version),
269274
release_notes=self._identity.release_notes,
270275
disable_shuffling=disable_shuffling,
276+
nondeterministic_order=nondeterministic_order,
271277
config_name=self._identity.config_name,
272278
config_description=self._identity.config_description,
273279
config_tags=self._identity.config_tags,
@@ -342,6 +348,7 @@ def from_proto(
342348
features=features,
343349
supervised_keys=supervised_keys,
344350
disable_shuffling=proto.disable_shuffling,
351+
nondeterministic_order=proto.nondeterministic_order,
345352
citation=proto.citation,
346353
license=proto.redistribution_info.license,
347354
split_dict=splits_lib.SplitDict.from_proto(
@@ -400,6 +407,13 @@ def release_notes(self) -> dict[str, str] | None:
400407
def disable_shuffling(self) -> bool:
401408
return self.as_proto.disable_shuffling
402409

410+
@property
411+
def nondeterministic_order(self) -> bool:
412+
return self._info_proto.nondeterministic_order
413+
414+
def set_nondeterministic_order(self, nondeterministic_order: bool) -> None:
415+
self._info_proto.nondeterministic_order = nondeterministic_order
416+
403417
@property
404418
def homepage(self) -> str:
405419
urls = self.as_proto.location.urls
@@ -923,6 +937,7 @@ def __repr__(self):
923937
("features", _indent(repr(self.features))),
924938
("supervised_keys", self.supervised_keys),
925939
("disable_shuffling", self.disable_shuffling),
940+
("nondeterministic_order", self.nondeterministic_order),
926941
("splits", splits),
927942
("citation", _indent(f'"""{self.citation}"""')),
928943
# Proto add a \n that we strip.
@@ -940,6 +955,7 @@ def __getstate__(self):
940955
"features": self.features,
941956
"supervised_keys": self.supervised_keys,
942957
"disable_shuffling": self.disable_shuffling,
958+
"nondeterministic_order": self.nondeterministic_order,
943959
"homepage": self.homepage,
944960
"citation": self.citation,
945961
"metadata": self.metadata,
@@ -956,6 +972,7 @@ def __setstate__(self, state):
956972
features=state["features"],
957973
supervised_keys=state["supervised_keys"],
958974
disable_shuffling=state["disable_shuffling"],
975+
nondeterministic_order=state["nondeterministic_order"],
959976
homepage=state["homepage"],
960977
citation=state["citation"],
961978
metadata=state["metadata"],

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def test_get_split_info_from_proto_unavailable_format(self):
818818
}),
819819
supervised_keys=('image', 'label'),
820820
disable_shuffling=False,
821+
nondeterministic_order=False,
821822
splits={
822823
'test': <SplitInfo num_examples=20, num_shards=1>,
823824
'train': <SplitInfo num_examples=20, num_shards=1>,

tensorflow_datasets/core/download/download_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class DownloadConfig:
108108
used.
109109
ignore_duplicates: whether to ignore duplicated examples with the same key.
110110
If there are multiple examples with the same key, the first one is kept.
111+
nondeterministic_order: If True, it will not ensure deterministic ordering
112+
when writing examples to disk in the case of beam datasets. This might
113+
result in quicker dataset preparation.
111114
"""
112115

113116
extract_dir: epath.PathLike | None = None
@@ -126,6 +129,7 @@ class DownloadConfig:
126129
min_shard_size: int = shard_utils.DEFAULT_MIN_SHARD_SIZE
127130
max_shard_size: int = shard_utils.DEFAULT_MAX_SHARD_SIZE
128131
ignore_duplicates: bool = False
132+
nondeterministic_order: bool = False
129133

130134
def get_shard_config(self) -> shard_utils.ShardConfig:
131135
return shard_utils.ShardConfig(

tensorflow_datasets/core/file_adapters.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@
2626
from typing import Any, ClassVar, Type, TypeVar
2727

2828
from etils import epy
29+
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
2930
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
3031
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
3132
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
3233
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3334

35+
3436
with epy.lazy_imports():
3537
# pylint: disable=g-import-not-at-top
3638
from etils import epath
39+
from tensorflow_datasets.core import naming
3740
from tensorflow_datasets.core.utils import file_utils
3841
from tensorflow_datasets.core.utils import type_utils
3942

@@ -167,6 +170,23 @@ def deserialize(cls, raw_example: bytes) -> Any:
167170
"""
168171
return tf.train.Example.FromString(raw_example)
169172

173+
@classmethod
174+
def beam_sink(
175+
cls,
176+
filename_template: naming.ShardedFileTemplate,
177+
num_shards: int | None = None,
178+
) -> beam.PTransform:
179+
"""Returns a Beam sink for writing examples in the given file format."""
180+
raise NotImplementedError()
181+
182+
@classmethod
183+
def num_examples(cls, filename: epath.PathLike) -> int:
184+
"""Returns the number of examples in the given file."""
185+
n = 0
186+
for _ in cls.make_tf_data(filename):
187+
n += 1
188+
return n
189+
170190

171191
class TfRecordFileAdapter(FileAdapter):
172192
"""File adapter for TFRecord file format."""
@@ -205,6 +225,20 @@ def write_examples(
205225
writer.write(serialized_example)
206226
writer.flush()
207227

228+
@classmethod
229+
def beam_sink(
230+
cls,
231+
filename_template: naming.ShardedFileTemplate,
232+
num_shards: int | None = None,
233+
) -> beam.PTransform:
234+
"""Returns a Beam sink for writing examples in the given file format."""
235+
file_path_prefix = filename_template.sharded_filepaths_pattern(
236+
num_shards=num_shards, use_at_notation=True
237+
).removesuffix('@*')
238+
return beam.io.WriteToTFRecord(
239+
file_path_prefix=file_path_prefix, num_shards=num_shards
240+
)
241+
208242

209243
class RiegeliFileAdapter(FileAdapter):
210244
"""File adapter for Riegeli file format."""

0 commit comments

Comments
 (0)