
Commit d2ad852

Committed by The TensorFlow Datasets Authors

Add types and errors to shuffler

PiperOrigin-RevId: 703402998
1 parent: 9d6090d

3 files changed: +36 −30 lines

tensorflow_datasets/core/shuffle.py

Lines changed: 22 additions & 23 deletions
@@ -21,9 +21,8 @@
 import struct
 from typing import Optional
 import uuid
-
 from absl import logging
-import six
+from etils import epath
 from tensorflow_datasets.core import hashing
 from tensorflow_datasets.core.utils import file_utils
 from tensorflow_datasets.core.utils import type_utils
@@ -57,14 +56,14 @@ def __init__(self, item1, item2):
     self.item2 = item2


-def _hkey_to_bytes(hkey):
+def _hkey_to_bytes(hkey: int) -> bytes:
   """Converts 128 bits integer hkey to binary representation."""
   max_int64 = 0xFFFFFFFFFFFFFFFF
   return struct.pack('=QQ', (hkey >> 64) & max_int64, hkey & max_int64)


-def _read_hkey(buff):
-  """Reads from fobj and returns hkey (128 bites integer)."""
+def _read_hkey(buff: bytes) -> int:
+  """Reads from fobj and returns hkey (128 bits integer)."""
   a, b = struct.unpack('=QQ', buff)
   return (a << 64) | b

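For context on this hunk: the two helpers round-trip a 128-bit hash key through two native-endian uint64s. A minimal standalone sanity check, copying the functions as they read after this change (the test value is arbitrary):

import struct

def _hkey_to_bytes(hkey: int) -> bytes:
  """Converts 128 bits integer hkey to binary representation."""
  max_int64 = 0xFFFFFFFFFFFFFFFF
  return struct.pack('=QQ', (hkey >> 64) & max_int64, hkey & max_int64)

def _read_hkey(buff: bytes) -> int:
  """Reads buff and returns hkey (128 bits integer)."""
  a, b = struct.unpack('=QQ', buff)
  return (a << 64) | b

hkey = 2**128 - 12345           # arbitrary 128-bit key
packed = _hkey_to_bytes(hkey)   # 16 bytes: high uint64 then low uint64
assert len(packed) == 16
assert _read_hkey(packed) == hkey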
@@ -99,7 +98,7 @@ def _increase_open_files_limit():


 def get_bucket_number(
-    hkey,
+    hkey: int,
     num_buckets: int,
     max_hkey: Optional[int] = None,
 ) -> int:
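The body of get_bucket_number is outside this hunk; only the signature gains the hkey: int annotation. As a rough illustration of what proportional bucketing over a 128-bit key space can look like, here is a hypothetical sketch (the name, clamping, and default max are assumptions, not the library's actual implementation):

def get_bucket_number_sketch(
    hkey: int,
    num_buckets: int,
    max_hkey: int | None = None,
) -> int:
  # Hypothetical: scale hkey linearly onto [0, num_buckets).
  max_hkey = max_hkey if max_hkey is not None else 2**128 - 1
  return min(num_buckets - 1, hkey * num_buckets // (max_hkey + 1))

assert get_bucket_number_sketch(0, 1000) == 0
assert get_bucket_number_sketch(2**128 - 1, 1000) == 999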
@@ -130,25 +129,25 @@ class _Bucket(object):
   ...
   """

-  def __init__(self, path):
+  def __init__(self, path: epath.Path):
     """Initialize a _Bucket instance.

     Args:
-      path (str): path to bucket file, where to write to or read from.
+      path: Path to bucket file, where to write to or read from.
     """
     self._path = path
     self._fobj = None
     self._length = 0
     self._size = 0

   @property
-  def size(self):
+  def size(self) -> int:
     return self._size

-  def __len__(self):
+  def __len__(self) -> int:
     return self._length

-  def add(self, key, data):
+  def add(self, key: type_utils.Key, data: bytes):
     """Adds (key, data) to bucket.

     Args:
@@ -216,18 +215,18 @@ class Shuffler(object):

   def __init__(
       self,
-      dirpath,
-      hash_salt,
+      dirpath: epath.PathLike,
+      hash_salt: str | bytes,
       disable_shuffling: bool = False,
       ignore_duplicates: bool = False,
   ):
     """Initialize Shuffler.

     Args:
-      dirpath (string): directory in which to store temporary files.
-      hash_salt (string or bytes): salt to hash keys.
-      disable_shuffling (bool): specify whether to shuffle by hashing the key.
-      ignore_duplicates: whether to ignore duplicated examples with the same
+      dirpath: Path to the directory in which to store temporary files.
+      hash_salt: Salt to hash keys.
+      disable_shuffling: Specifies whether to shuffle by hashing the key.
+      ignore_duplicates: Whether to ignore duplicated examples with the same
         key. If there are multiple examples with the same key, the first one is
         kept. If this is False, then a `DuplicatedKeysError` is raised.
     """
@@ -238,7 +237,7 @@ def __init__(
     self._buckets: list[_Bucket] = []
     for i in range(BUCKETS_NUMBER):
       bucket_name = 'bucket_%s_%03d.tmp' % (grp_name, i)
-      path = os.path.join(dirpath, bucket_name)
+      path = epath.Path(dirpath) / bucket_name
       self._buckets.append(_Bucket(path))
     self._read_only = False
     self._total_bytes = 0
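The bucket path now goes through etils' epath.Path, which handles both local and remote filesystems and joins with the / operator, instead of os.path.join. A small sketch of the equivalence for a local POSIX path (directory and bucket names here are made up):

import os
from etils import epath

dirpath = '/tmp/shuffle_dir'
bucket_name = 'bucket_ab12_000.tmp'
path = epath.Path(dirpath) / bucket_name   # epath join, as in the new code
assert os.fspath(path) == os.path.join(dirpath, bucket_name)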
@@ -263,25 +262,25 @@ def bucket_lengths(self) -> Sequence[int]:
   def num_examples(self) -> int:
     return self._num_examples

-  def _add_to_bucket(self, hkey, data) -> None:
+  def _add_to_bucket(self, hkey: int, data: bytes) -> None:
     bucket_number = get_bucket_number(hkey=hkey, num_buckets=BUCKETS_NUMBER)
     self._buckets[bucket_number].add(hkey, data)

-  def _add_to_mem_buffer(self, hkey, data) -> None:
+  def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None:
     self._mem_buffer.append((hkey, data))
     if self._total_bytes > MAX_MEM_BUFFER_SIZE:
       for hkey, data in self._mem_buffer:
         self._add_to_bucket(hkey, data)
       self._mem_buffer = None
       self._in_memory = False

-  def add(self, key, data) -> bool:
+  def add(self, key: type_utils.Key, data: bytes) -> bool:
     """Add (key, data) to shuffler."""
     if self._read_only:
       raise AssertionError('add() cannot be called after __iter__.')
-    if not isinstance(data, six.binary_type):
+    if not isinstance(data, bytes):
       raise AssertionError(
-          'Only bytes (not %s) can be stored in Shuffler!' % (type(data))
+          f'Only bytes (not {type(data)}) can be stored in Shuffler!'
       )
     hkey = self._hasher.hash_key(key)
     if self._ignore_duplicates:
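With the six dependency gone, add() checks isinstance(data, bytes) directly and raises with an f-string message. Illustrative usage, assuming a local temp directory and a plain string salt (the constructor values here are examples, not from this diff):

from tensorflow_datasets.core import shuffle

shuffler = shuffle.Shuffler(dirpath='/tmp/shuffle_dir', hash_salt='salt')
shuffler.add('key-1', b'serialized example bytes')  # OK: payload is bytes
try:
  shuffler.add('key-2', 'not bytes')  # str payload is rejected
except AssertionError as e:
  print(e)  # Only bytes (not <class 'str'>) can be stored in Shuffler!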

tensorflow_datasets/core/split_builder.py

Lines changed: 12 additions & 7 deletions
@@ -505,18 +505,23 @@ def _build_from_generator(
         example_writer=self._example_writer,
         ignore_duplicates=self._ignore_duplicates,
     )
-    for key, example in utils.tqdm(
-        generator,
-        desc=f'Generating {split_name} examples...',
-        unit=' examples',
-        total=total_num_examples,
-        leave=False,
-        mininterval=1.0,
+    for i, (key, example) in enumerate(
+        utils.tqdm(
+            generator,
+            desc=f'Generating {split_name} examples...',
+            unit=' examples',
+            total=total_num_examples,
+            leave=False,
+            mininterval=1.0,
+        )
     ):
       try:
         example = self._features.encode_example(example)
       except Exception as e:  # pylint: disable=broad-except
         utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
+      if disable_shuffling and not isinstance(key, int):
+        # If `disable_shuffling` is set to True, the key must be an integer.
+        key = i
       writer.write(key, example)
     try:
       shard_lengths, total_size = writer.finalize()
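The enumerate wrapper gives each generated example a positional index i, and the new block after encoding replaces any non-integer key with that index when disable_shuffling is set, since unshuffled writing requires integer keys. A standalone sketch of just that fallback (the generator contents are invented for illustration):

disable_shuffling = True
generator = [('a', 'ex0'), (7, 'ex1'), ('c', 'ex2')]  # (key, example) pairs

written = []
for i, (key, example) in enumerate(generator):
  if disable_shuffling and not isinstance(key, int):
    # Non-integer keys are replaced by the example's position.
    key = i
  written.append((key, example))

assert written == [(0, 'ex0'), (7, 'ex1'), (2, 'ex2')]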

tensorflow_datasets/core/split_builder_test.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@
 import psutil
 import pytest
 from tensorflow_datasets import testing
+from tensorflow_datasets.core import file_adapters
+from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import split_builder as split_builder_lib
 from tensorflow_datasets.core import writer as writer_lib