Skip to content

Commit a38292f

Browse files
author
The TensorFlow Datasets Authors
committed
Add a beam writer that doesn't shuffle
PiperOrigin-RevId: 691657172
1 parent fb68321 commit a38292f

14 files changed

+136
-409
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,6 @@ def download_and_prepare(
623623
data_path = self.data_path
624624
data_exists = data_path.exists()
625625

626-
# Saving nondeterministic_order in the DatasetInfo for documentation.
627-
if download_config.nondeterministic_order:
628-
self.info.set_nondeterministic_order(True)
629-
630626
if download_config.download_mode == UPDATE_DATASET_INFO:
631627
self._update_dataset_info()
632628
return
@@ -1431,13 +1427,11 @@ def _get_filename_template(
14311427
self, split_name: str
14321428
) -> naming.ShardedFileTemplate:
14331429
"""Returns a filename template for the given split."""
1434-
if self.info.file_format is None:
1435-
raise ValueError("File format is not set!")
14361430
return naming.ShardedFileTemplate(
14371431
split=split_name,
14381432
dataset_name=self.name,
14391433
data_dir=self.data_path,
1440-
filetype_suffix=self.info.file_format.file_suffix,
1434+
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
14411435
)
14421436

14431437

@@ -1735,7 +1729,6 @@ def _generate_splits(
17351729
generator=generator,
17361730
filename_template=filename_template,
17371731
disable_shuffling=self.info.disable_shuffling,
1738-
nondeterministic_order=download_config.nondeterministic_order,
17391732
)
17401733
split_info_futures.append(future)
17411734

tensorflow_datasets/core/dataset_builder_beam_test.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,14 @@ class DummyBeamDataset(dataset_builder.GeneratorBasedBuilder):
3939
'valid_725': 725,
4040
}
4141

42-
FEATURE_DICT = features.FeaturesDict({
43-
'image': features.Image(shape=(16, 16, 1)),
44-
'label': features.ClassLabel(names=['dog', 'cat']),
45-
'id': tf.int32,
46-
})
47-
4842
def _info(self):
4943
return dataset_info.DatasetInfo(
5044
builder=self,
51-
features=self.FEATURE_DICT,
45+
features=features.FeaturesDict({
46+
'image': features.Image(shape=(16, 16, 1)),
47+
'label': features.ClassLabel(names=['dog', 'cat']),
48+
'id': tf.int32,
49+
}),
5250
supervised_keys=('x', 'x'),
5351
metadata=dataset_info.BeamMetadataDict(),
5452
)
@@ -73,18 +71,6 @@ def _generate_examples(self, num_examples):
7371
return examples
7472

7573

76-
class UnshuffledDummyBeamDataset(DummyBeamDataset):
77-
78-
def _info(self) -> dataset_info.DatasetInfo:
79-
return dataset_info.DatasetInfo(
80-
builder=self,
81-
features=self.FEATURE_DICT,
82-
supervised_keys=('x', 'x'),
83-
metadata=dataset_info.BeamMetadataDict(),
84-
disable_shuffling=True,
85-
)
86-
87-
8874
class CommonPipelineDummyBeamDataset(DummyBeamDataset):
8975
EXPECTED_METADATA = {
9076
'label_sum_1000': 500,
@@ -165,21 +151,12 @@ def _compute_mean(examples):
165151
)
166152

167153

168-
def get_id(ex):
169-
return ex['id']
170-
171-
172154
def make_default_config():
173155
return download.DownloadConfig()
174156

175157

176158
@pytest.mark.parametrize(
177-
'dataset_cls',
178-
[
179-
DummyBeamDataset,
180-
CommonPipelineDummyBeamDataset,
181-
UnshuffledDummyBeamDataset,
182-
],
159+
'dataset_cls', [DummyBeamDataset, CommonPipelineDummyBeamDataset]
183160
)
184161
@pytest.mark.parametrize(
185162
'make_dl_config',
@@ -201,23 +178,29 @@ def test_beam_datasets(
201178
assert data_path.exists() # Dataset has been generated
202179

203180
# Check number of shards/generated files
204-
for split in ['test', 'train']:
205-
_test_shards(
206-
data_path,
207-
pattern='%s-%s.tfrecord-{:05}-of-{:05}' % (dataset_name, split),
208-
num_shards=builder.info.splits[split].num_shards,
209-
)
181+
_test_shards(
182+
data_path,
183+
pattern='%s-test.tfrecord-{:05}-of-{:05}' % dataset_name,
184+
# Liquid sharding is not guaranteed to always use the same number.
185+
num_shards=builder.info.splits['test'].num_shards,
186+
)
187+
_test_shards(
188+
data_path,
189+
pattern='%s-train.tfrecord-{:05}-of-{:05}' % dataset_name,
190+
num_shards=1,
191+
)
210192

211193
ds = dataset_utils.as_numpy(builder.as_dataset())
212194

213-
test_examples = list(ds['test'])
214-
train_examples = list(ds['train'])
195+
def get_id(ex):
196+
return ex['id']
197+
215198
_assert_values_equal(
216-
sorted(test_examples, key=get_id),
199+
sorted(list(ds['test']), key=get_id),
217200
sorted([_gen_example(i)[1] for i in range(725)], key=get_id),
218201
)
219202
_assert_values_equal(
220-
sorted(train_examples, key=get_id),
203+
sorted(list(ds['train']), key=get_id),
221204
sorted([_gen_example(i)[1] for i in range(1000)], key=get_id),
222205
)
223206

tensorflow_datasets/core/dataset_info.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def __init__(
186186
features: feature_lib.FeatureConnector | None = None,
187187
supervised_keys: SupervisedKeysType | None = None,
188188
disable_shuffling: bool = False,
189-
nondeterministic_order: bool = False,
190189
homepage: str | None = None,
191190
citation: str | None = None,
192191
metadata: Metadata | None = None,
@@ -229,11 +228,7 @@ def __init__(
229228
230229
Note that selecting features in nested `tfds.features.FeaturesDict`
231230
objects is not supported.
232-
disable_shuffling: `bool`, specifies whether to shuffle the examples.
233-
nondeterministic_order: `bool`, if True and the dataset uses beam, it will
234-
use `NoShuffleBeamWriter` which does not assure deterministic
235-
shuffling when writing examples to disk. This might result in quicker
236-
dataset preparation.
231+
disable_shuffling: `bool`, specifies whether to shuffle the examples.
237232
homepage: `str`, optional, the homepage for this dataset.
238233
citation: `str`, optional, the citation to use for this dataset.
239234
metadata: `tfds.core.Metadata`, additional object which will be
@@ -273,7 +268,6 @@ def __init__(
273268
version=str(self._identity.version),
274269
release_notes=self._identity.release_notes,
275270
disable_shuffling=disable_shuffling,
276-
nondeterministic_order=nondeterministic_order,
277271
config_name=self._identity.config_name,
278272
config_description=self._identity.config_description,
279273
config_tags=self._identity.config_tags,
@@ -348,7 +342,6 @@ def from_proto(
348342
features=features,
349343
supervised_keys=supervised_keys,
350344
disable_shuffling=proto.disable_shuffling,
351-
nondeterministic_order=proto.nondeterministic_order,
352345
citation=proto.citation,
353346
license=proto.redistribution_info.license,
354347
split_dict=splits_lib.SplitDict.from_proto(
@@ -407,13 +400,6 @@ def release_notes(self) -> dict[str, str] | None:
407400
def disable_shuffling(self) -> bool:
408401
return self.as_proto.disable_shuffling
409402

410-
@property
411-
def nondeterministic_order(self) -> bool:
412-
return self._info_proto.nondeterministic_order
413-
414-
def set_nondeterministic_order(self, nondeterministic_order: bool) -> None:
415-
self._info_proto.nondeterministic_order = nondeterministic_order
416-
417403
@property
418404
def homepage(self) -> str:
419405
urls = self.as_proto.location.urls
@@ -937,7 +923,6 @@ def __repr__(self):
937923
("features", _indent(repr(self.features))),
938924
("supervised_keys", self.supervised_keys),
939925
("disable_shuffling", self.disable_shuffling),
940-
("nondeterministic_order", self.nondeterministic_order),
941926
("splits", splits),
942927
("citation", _indent(f'"""{self.citation}"""')),
943928
# Proto add a \n that we strip.
@@ -955,7 +940,6 @@ def __getstate__(self):
955940
"features": self.features,
956941
"supervised_keys": self.supervised_keys,
957942
"disable_shuffling": self.disable_shuffling,
958-
"nondeterministic_order": self.nondeterministic_order,
959943
"homepage": self.homepage,
960944
"citation": self.citation,
961945
"metadata": self.metadata,
@@ -972,7 +956,6 @@ def __setstate__(self, state):
972956
features=state["features"],
973957
supervised_keys=state["supervised_keys"],
974958
disable_shuffling=state["disable_shuffling"],
975-
nondeterministic_order=state["nondeterministic_order"],
976959
homepage=state["homepage"],
977960
citation=state["citation"],
978961
metadata=state["metadata"],

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,6 @@ def test_get_split_info_from_proto_unavailable_format(self):
818818
}),
819819
supervised_keys=('image', 'label'),
820820
disable_shuffling=False,
821-
nondeterministic_order=False,
822821
splits={
823822
'test': <SplitInfo num_examples=20, num_shards=1>,
824823
'train': <SplitInfo num_examples=20, num_shards=1>,

tensorflow_datasets/core/download/download_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@ class DownloadConfig:
108108
used.
109109
ignore_duplicates: whether to ignore duplicated examples with the same key.
110110
If there are multiple examples with the same key, the first one is kept.
111-
nondeterministic_order: If True, it will not assure deterministic ordering
112-
when writing examples to disk in the case of beam datasets. This might
113-
result in quicker dataset preparation.
114111
"""
115112

116113
extract_dir: epath.PathLike | None = None
@@ -129,7 +126,6 @@ class DownloadConfig:
129126
min_shard_size: int = shard_utils.DEFAULT_MIN_SHARD_SIZE
130127
max_shard_size: int = shard_utils.DEFAULT_MAX_SHARD_SIZE
131128
ignore_duplicates: bool = False
132-
nondeterministic_order: bool = False
133129

134130
def get_shard_config(self) -> shard_utils.ShardConfig:
135131
return shard_utils.ShardConfig(

tensorflow_datasets/core/file_adapters.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,14 @@
2626
from typing import Any, ClassVar, Type, TypeVar
2727

2828
from etils import epy
29-
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
3029
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
3130
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
3231
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
3332
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3433

35-
3634
with epy.lazy_imports():
3735
# pylint: disable=g-import-not-at-top
3836
from etils import epath
39-
from tensorflow_datasets.core import naming
4037
from tensorflow_datasets.core.utils import file_utils
4138
from tensorflow_datasets.core.utils import type_utils
4239

@@ -170,23 +167,6 @@ def deserialize(cls, raw_example: bytes) -> Any:
170167
"""
171168
return tf.train.Example.FromString(raw_example)
172169

173-
@classmethod
174-
def beam_sink(
175-
cls,
176-
filename_template: naming.ShardedFileTemplate,
177-
num_shards: int | None = None,
178-
) -> beam.PTransform:
179-
"""Returns a Beam sink for writing examples in the given file format."""
180-
raise NotImplementedError()
181-
182-
@classmethod
183-
def num_examples(cls, filename: epath.PathLike) -> int:
184-
"""Returns the number of examples in the given file."""
185-
n = 0
186-
for _ in cls.make_tf_data(filename):
187-
n += 1
188-
return n
189-
190170

191171
class TfRecordFileAdapter(FileAdapter):
192172
"""File adapter for TFRecord file format."""
@@ -225,20 +205,6 @@ def write_examples(
225205
writer.write(serialized_example)
226206
writer.flush()
227207

228-
@classmethod
229-
def beam_sink(
230-
cls,
231-
filename_template: naming.ShardedFileTemplate,
232-
num_shards: int | None = None,
233-
) -> beam.PTransform:
234-
"""Returns a Beam sink for writing examples in the given file format."""
235-
file_path_prefix = filename_template.sharded_filepaths_pattern(
236-
num_shards=num_shards, use_at_notation=True
237-
).removesuffix('@*')
238-
return beam.io.WriteToTFRecord(
239-
file_path_prefix=file_path_prefix, num_shards=num_shards
240-
)
241-
242208

243209
class RiegeliFileAdapter(FileAdapter):
244210
"""File adapter for Riegeli file format."""

0 commit comments

Comments
 (0)