
Commit 3a34edd

tomvdw authored and The TensorFlow Datasets Authors committed

Add option to ignore duplicates in DownloadConfig

When duplicates are ignored, only one of the multiple examples with the same key is kept and no exception is raised.

PiperOrigin-RevId: 624137980
1 parent 364fb82 commit 3a34edd
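
As a usage sketch (the dataset name below is a hypothetical placeholder; `tfds.download.DownloadConfig` and `download_and_prepare` are existing TFDS APIs, and `ignore_duplicates` is the flag this commit adds):

import tensorflow_datasets as tfds

# Keep the first example per key instead of raising DuplicatedKeysError
# when a builder emits several examples with the same key.
download_config = tfds.download.DownloadConfig(ignore_duplicates=True)

builder = tfds.builder('my_dataset')  # hypothetical builder name
builder.download_and_prepare(download_config=download_config)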

File tree: 6 files changed, +120 −39 lines changed

tensorflow_datasets/core/dataset_builder.py (1 addition, 0 deletions)

@@ -1596,6 +1596,7 @@ def _generate_splits(
         beam_runner=download_config.beam_runner,
         shard_config=download_config.get_shard_config(),
         example_writer=self._example_writer(),
+        ignore_duplicates=download_config.ignore_duplicates,
     )
     # Wrap the generation inside a context manager.
     # If `beam` is used during generation (when a pipeline gets created),

tensorflow_datasets/core/download/download_manager.py (5 additions, 3 deletions)

@@ -14,6 +14,7 @@
 # limitations under the License.

 """Download manager interface."""
+
 from __future__ import annotations

 import concurrent.futures
@@ -100,6 +101,8 @@ class DownloadConfig:
       used.
     max_shard_size: optional maximum shard size in bytes. If `None`, 1 GiB is
       used.
+    ignore_duplicates: whether to ignore duplicated examples with the same key.
+      If there are multiple examples with the same key, the first one is kept.
   """

   extract_dir: Optional[epath.PathLike] = None
@@ -117,6 +120,7 @@ class DownloadConfig:
   num_shards: Optional[int] = None
   min_shard_size: int = shard_utils.DEFAULT_MIN_SHARD_SIZE
   max_shard_size: int = shard_utils.DEFAULT_MAX_SHARD_SIZE
+  ignore_duplicates: bool = False

   def get_shard_config(self) -> shard_utils.ShardConfig:
     return shard_utils.ShardConfig(
@@ -248,9 +252,7 @@ def __init__(

     self._download_dir: epath.Path = download_dir
     self._extract_dir: epath.Path = extract_dir
-    self._manual_dir: Optional[epath.Path] = (
-        manual_dir  # pytype: disable=annotation-type-mismatch  # attribute-variable-annotations
-    )
+    self._manual_dir: Optional[epath.Path] = manual_dir  # pytype: disable=annotation-type-mismatch  # attribute-variable-annotations
     self._manual_dir_instructions = utils.dedent(manual_dir_instructions)
     self._download_dir.mkdir(parents=True, exist_ok=True)
     self._extract_dir.mkdir(parents=True, exist_ok=True)

tensorflow_datasets/core/shuffle.py (32 additions, 10 deletions)

@@ -15,10 +15,11 @@

 """To shuffle records (stable)."""

+from collections.abc import Iterator, Sequence
 import math
 import os
 import struct
-from typing import Iterator, List, Optional
+from typing import Optional
 import uuid

 from absl import logging
@@ -213,18 +214,28 @@ def del_file(self):
 class Shuffler(object):
   """Stores data in temp buckets, restitute it shuffled."""

-  def __init__(self, dirpath, hash_salt, disable_shuffling: bool = False):
+  def __init__(
+      self,
+      dirpath,
+      hash_salt,
+      disable_shuffling: bool = False,
+      ignore_duplicates: bool = False,
+  ):
     """Initialize Shuffler.

     Args:
       dirpath (string): directory in which to store temporary files.
       hash_salt (string or bytes): salt to hash keys.
       disable_shuffling (bool): specify whether to shuffle by hashing the key.
+      ignore_duplicates: whether to ignore duplicated examples with the same
+        key. If there are multiple examples with the same key, the first one is
+        kept. If this is False, then a `DuplicatedKeysError` is raised.
     """
     grp_name = uuid.uuid4()
     self._hasher = hashing.Hasher(hash_salt)
     self._disable_shuffling = disable_shuffling
-    self._buckets: List[_Bucket] = []
+    self._ignore_duplicates = ignore_duplicates
+    self._buckets: list[_Bucket] = []
     for i in range(BUCKETS_NUMBER):
       bucket_name = 'bucket_%s_%03d.tmp' % (grp_name, i)
       path = os.path.join(dirpath, bucket_name)
@@ -234,47 +245,58 @@ def __init__(self, dirpath, hash_salt, disable_shuffling: bool = False):
     # To keep data in memory until enough data has been gathered.
     self._in_memory = True
     self._mem_buffer = []
+    self._seen_keys: set[int] = set()
+    self._num_examples = 0

   @property
-  def size(self):
+  def size(self) -> int:
     """Return total size in bytes of records (not keys)."""
     return self._total_bytes

   @property
-  def bucket_lengths(self):
+  def bucket_lengths(self) -> Sequence[int]:
     if self._in_memory:
       return [len(self._mem_buffer)]
     return [len(b) for b in self._buckets]

-  def _add_to_bucket(self, hkey, data):
+  @property
+  def num_examples(self) -> int:
+    return self._num_examples
+
+  def _add_to_bucket(self, hkey, data) -> None:
     bucket_number = get_bucket_number(hkey=hkey, num_buckets=BUCKETS_NUMBER)
     self._buckets[bucket_number].add(hkey, data)

-  def _add_to_mem_buffer(self, hkey, data):
+  def _add_to_mem_buffer(self, hkey, data) -> None:
     self._mem_buffer.append((hkey, data))
     if self._total_bytes > MAX_MEM_BUFFER_SIZE:
       for hkey, data in self._mem_buffer:
         self._add_to_bucket(hkey, data)
       self._mem_buffer = None
       self._in_memory = False

-  def add(self, key, data):
+  def add(self, key, data) -> bool:
     """Add (key, data) to shuffler."""
     if self._read_only:
       raise AssertionError('add() cannot be called after __iter__.')
     if not isinstance(data, six.binary_type):
       raise AssertionError(
           'Only bytes (not %s) can be stored in Shuffler!' % (type(data))
       )
+    hkey = self._hasher.hash_key(key)
+    if self._ignore_duplicates:
+      if hkey in self._seen_keys:
+        return
+      self._seen_keys.add(hkey)
     if self._disable_shuffling:
+      # Use the original key and not the hashed key to maintain the order.
       hkey = key
-    else:
-      hkey = self._hasher.hash_key(key)
     self._total_bytes += len(data)
     if self._in_memory:
       self._add_to_mem_buffer(hkey, data)
     else:
       self._add_to_bucket(hkey, data)
+    self._num_examples += 1

   def __iter__(self) -> Iterator[type_utils.KeySerializedExample]:
     self._read_only = True
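
The Shuffler change above implements first-wins deduplication keyed on the salted hash of each example key. A minimal standalone sketch of the same pattern, using the standard library's hashlib in place of TFDS's internal hashing.Hasher (which is salted):

import hashlib

def dedup_first_wins(examples):
  """Yields (key, data) pairs, keeping only the first example per hashed key."""
  seen_keys = set()
  for key, data in examples:
    # Stand-in for hashing.Hasher.hash_key: derive an int from the key bytes.
    hkey = int.from_bytes(hashlib.sha256(str(key).encode()).digest()[:8], 'big')
    if hkey in seen_keys:
      continue  # Duplicate key: silently dropped, as with ignore_duplicates=True.
    seen_keys.add(hkey)
    yield key, data

# e.g. list(dedup_first_wins([('a', b'1'), ('a', b'2')])) == [('a', b'1')]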

tensorflow_datasets/core/split_builder.py (4 additions, 0 deletions)

@@ -132,6 +132,7 @@ def __init__(
       max_examples_per_split: int | None,
       example_writer: writer_lib.ExampleWriter,
       shard_config: shard_utils.ShardConfig | None = None,
+      ignore_duplicates: bool = False,
   ):
     self._split_dict = split_dict
     self._features = features
@@ -143,6 +144,7 @@ def __init__(
     self._beam_runner = beam_runner
     self._beam_pipeline: Optional['beam.Pipeline'] = None
     self._shard_config = shard_config
+    self._ignore_duplicates = ignore_duplicates
     self._example_writer = example_writer

   @contextlib.contextmanager
@@ -386,6 +388,7 @@ def _build_from_generator(
         disable_shuffling=disable_shuffling,
         shard_config=self._shard_config,
         example_writer=self._example_writer,
+        ignore_duplicates=self._ignore_duplicates,
     )
     for key, example in utils.tqdm(
         generator,
@@ -428,6 +431,7 @@ def _build_from_pcollection(
         disable_shuffling=disable_shuffling,
         shard_config=self._shard_config,
         example_writer=self._example_writer,
+        ignore_duplicates=self._ignore_duplicates,
     )

     def _encode_example(key_ex, encode_fn=self._features.encode_example):

tensorflow_datasets/core/writer.py (43 additions, 24 deletions)

@@ -73,11 +73,19 @@ def _raise_error_for_duplicated_keys(example1, example2, example_specs):
   """Log information about the examples and raise an AssertionError."""
   msg = "Two examples share the same hashed key!"
   logging.error(msg)
-  parser = example_parser.ExampleParser(example_specs)
-  ex1 = parser.parse_example(example1)
-  ex2 = parser.parse_example(example2)
-  logging.error("1st example: %s", ex1)
-  logging.error("2nd example: %s", ex2)
+  try:
+    parser = example_parser.ExampleParser(example_specs)
+    ex1 = parser.parse_example(example1)
+    ex2 = parser.parse_example(example2)
+    logging.error("1st example: %s", ex1)
+    logging.error("2nd example: %s", ex2)
+  except ValueError:
+    logging.error(
+        "Failed to parse examples! Cannot log them to see the examples behind"
+        " the duplicated keys. Raw example 1: %s, raw example 2: %s",
+        example1,
+        example2,
+    )
   raise AssertionError(msg + " See logs above to view the examples.")


@@ -192,6 +200,7 @@ def __init__(
       disable_shuffling: bool,
       example_writer: ExampleWriter,
       shard_config: shard_utils.ShardConfig | None = None,
+      ignore_duplicates: bool = False,
   ):
     """Initializes Writer.

@@ -202,14 +211,16 @@ def __init__(
       disable_shuffling (bool): Specifies whether to shuffle the records.
       example_writer: class that writes examples to disk or elsewhere.
       shard_config: the configuration for creating shards.
+      ignore_duplicates: whether to ignore duplicated examples with the same
+        key. If False, a `DuplicatedKeysError` will be raised on duplicates.
     """
     self._serializer = serializer
     self._shuffler = shuffle.Shuffler(
         dirpath=filename_template.data_dir,
         hash_salt=hash_salt,
         disable_shuffling=disable_shuffling,
+        ignore_duplicates=ignore_duplicates,
     )
-    self._num_examples = 0
     self._filename_template = filename_template
     self._shard_config = shard_config or shard_utils.ShardConfig()
     self._example_writer = example_writer
@@ -226,13 +237,12 @@ def write(self, key: int | bytes, example: Example):
     """
     serialized_example = self._serializer.serialize_example(example=example)
     self._shuffler.add(key, serialized_example)
-    self._num_examples += 1

   def finalize(self) -> tuple[list[int], int]:
     """Effectively writes examples to the shards."""
     filename = self._filename_template.sharded_filepaths_pattern()
     shard_specs = _get_shard_specs(
-        num_examples=self._num_examples,
+        num_examples=self._shuffler.num_examples,
         total_size=self._shuffler.size,
         bucket_lengths=self._shuffler.bucket_lengths,
         filename_template=self._filename_template,
@@ -245,7 +255,7 @@ def finalize(self) -> tuple[list[int], int]:
         utils.tqdm(
             self._shuffler,
             desc=f"Shuffling {filename}...",
-            total=self._num_examples,
+            total=self._shuffler.num_examples,
             unit=" examples",
             leave=False,
             mininterval=1.0,
@@ -322,6 +332,7 @@ def __init__(
       disable_shuffling: bool,
       example_writer: ExampleWriter,
       shard_config: shard_utils.ShardConfig | None = None,
+      ignore_duplicates: bool = False,
   ):
     """Init BeamWriter.

@@ -336,6 +347,8 @@ def __init__(
       disable_shuffling: bool, specifies whether to shuffle the records.
       example_writer: class that writes examples to storage.
       shard_config: the configuration for creating shards.
+      ignore_duplicates: whether to ignore duplicated examples with the same
+        key. If False, a `DuplicatedKeysError` will be raised on duplicates.
     """
     self._original_state = dict(
         serializer=serializer,
@@ -344,6 +357,7 @@ def __init__(
         disable_shuffling=disable_shuffling,
         shard_config=shard_config,
         example_writer=example_writer,
+        ignore_duplicates=ignore_duplicates,
     )
     self._filename_template = filename_template
     self._split_info_path = (
@@ -355,6 +369,7 @@ def __init__(
     self._disable_shuffling = disable_shuffling
     self._shard_config = shard_config or shard_utils.ShardConfig()
     self._example_writer = example_writer
+    self._ignore_duplicates = ignore_duplicates

   @functools.lru_cache()
   def _get_counter(self, name: str, namespace: str = "BeamWriter"):
@@ -416,29 +431,34 @@ def _write_final_shard(
       raise AssertionError("Not a single example present in the PCollection!")
     # There may be empty shards, this ensures there are no gaps.
     shard_id = non_empty_shard_ids.index(original_shard_id)
-    examples = sorted(examples)
-    self._get_distribution(name="ShardLenDistribution").update(len(examples))
-    # Compare continuous examples
-    for ex0, ex1 in zip(examples[:-1], examples[1:]):
-      if ex0[0] == ex1[0]:  # Different keys
-        _raise_error_for_duplicated_keys(
-            ex0[1], ex1[1], self._serializer.example_specs
-        )
+    example_by_key = {}
+    for key, example in examples:
+      if key in example_by_key:
+        if not self._ignore_duplicates:
+          _raise_error_for_duplicated_keys(
+              example_by_key[key], example, self._serializer.example_specs
+          )
+      else:
+        example_by_key[key] = example
     shard_path = self._filename_template.sharded_filepath(
         shard_index=shard_id, num_shards=len(non_empty_shard_ids)
     )
     with utils.incomplete_file(epath.Path(shard_path)) as tmp_path:
       logging.info(
-          "Writing %d examples to %s.", len(examples), os.fspath(tmp_path)
+          "Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
+      )
+      record_keys = self._example_writer.write(
+          tmp_path, sorted(example_by_key.items())
       )
-      record_keys = self._example_writer.write(tmp_path, examples)
     self.inc_counter(name="written_shards")
     # If there are record_keys, create index files.
     if record_keys:
       index_path = _get_index_path(os.fspath(shard_path))
       _write_index_file(index_path, list(record_keys))
-    shard_size = sum(map(len, examples))
-    return _ShardInfo(id=shard_id, num_examples=len(examples), size=shard_size)
+    shard_size = sum(map(len, example_by_key.values()))
+    return _ShardInfo(
+        id=shard_id, num_examples=len(example_by_key), size=shard_size
+    )

   def _number_of_shards(self, num_examples: int, total_size: int) -> int:
     """Returns the number of shards."""
@@ -468,11 +488,11 @@ def _assign_shard(
   def _store_split_info(
       self,
       shard_infos: Sequence[_ShardInfo],
-      total_size: int,
   ) -> None:
     """Stores the split info to disk."""
     shard_infos = sorted(shard_infos, key=lambda x: x.id)
     shard_lengths = [info.num_examples for info in shard_infos]
+    total_size = sum([info.size for info in shard_infos])
     with utils.incomplete_file(epath.Path(self._split_info_path)) as tmp_path:
       tmp_path.write_text(
           json.dumps({"total_size": total_size, "shard_lengths": shard_lengths})
@@ -553,8 +573,7 @@ def write_from_pcollection(self, examples_pcollection):
         # (_ShardInfo)
         | "CollectShardInfo" >> beam.transforms.combiners.ToList()
         # [_ShardInfo]
-        | "CalculateSplitInfo"
-        >> beam.ParDo(self._store_split_info, total_size=total_size)
+        | "CalculateSplitInfo" >> beam.ParDo(self._store_split_info)
     )

   def finalize(self) -> tuple[list[int], int]:
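
In the Beam path, `_write_final_shard` above replaces the sorted-neighbor key comparison with a dict that keeps the first example seen per key; `_raise_error_for_duplicated_keys` now fires only when the flag is off. A simplified sketch of just that dedup step, under the assumption that examples arrive as (key, example) pairs (the real method also writes the shard and returns a _ShardInfo):

def dedup_shard(examples, ignore_duplicates):
  """First-wins dedup over (key, example) pairs, as in _write_final_shard."""
  example_by_key = {}
  for key, example in examples:
    if key in example_by_key:
      if not ignore_duplicates:
        raise AssertionError(f'Two examples share the same hashed key! key={key}')
    else:
      example_by_key[key] = example
  # The shard is written in key order, hence the sort before writing.
  return sorted(example_by_key.items())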
