 import struct
 from typing import Optional
 import uuid
-
 from absl import logging
-import six
+from etils import epath
 from tensorflow_datasets.core import hashing
 from tensorflow_datasets.core.utils import file_utils
 from tensorflow_datasets.core.utils import type_utils
@@ -57,14 +56,14 @@ def __init__(self, item1, item2):
     self.item2 = item2
 
 
-def _hkey_to_bytes(hkey):
+def _hkey_to_bytes(hkey: int) -> bytes:
   """Converts 128 bits integer hkey to binary representation."""
   max_int64 = 0xFFFFFFFFFFFFFFFF
   return struct.pack('=QQ', (hkey >> 64) & max_int64, hkey & max_int64)
 
 
-def _read_hkey(buff):
-  """Reads from fobj and returns hkey (128 bites integer)."""
+def _read_hkey(buff: bytes) -> int:
+  """Reads from fobj and returns hkey (128 bits integer)."""
   a, b = struct.unpack('=QQ', buff)
   return (a << 64) | b
 
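The two helpers above are inverses of each other: a 128-bit key is split into two unsigned 64-bit halves, packed with `struct`, and recombined on read. Below is a minimal, self-contained sketch of that round trip; it mirrors the packing format shown in the diff rather than importing the private helpers.

import struct

MAX_INT64 = 0xFFFFFFFFFFFFFFFF

def hkey_to_bytes(hkey: int) -> bytes:
  # Pack the high and low 64-bit halves as two native-order unsigned 64-bit ints.
  return struct.pack('=QQ', (hkey >> 64) & MAX_INT64, hkey & MAX_INT64)

def read_hkey(buff: bytes) -> int:
  # Unpack the two halves and recombine them into one 128-bit integer.
  a, b = struct.unpack('=QQ', buff)
  return (a << 64) | b

hkey = (2**128 - 1) // 3                # arbitrary 128-bit value
assert read_hkey(hkey_to_bytes(hkey)) == hkey
assert len(hkey_to_bytes(hkey)) == 16   # 2 * 8 bytes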
@@ -99,7 +98,7 @@ def _increase_open_files_limit():
 
 
 def get_bucket_number(
-    hkey,
+    hkey: int,
     num_buckets: int,
     max_hkey: Optional[int] = None,
 ) -> int:
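`get_bucket_number` (whose body is not part of this diff) maps a hashed key onto one of a fixed number of buckets. Purely as an illustration of how such a mapping can work, here is a sketch; the formula and the constants below are assumptions, not the library's actual code.

NUM_BUCKETS = 1000           # assumed; the real constant is BUCKETS_NUMBER
DEFAULT_MAX_HKEY = 2**128    # assumed upper bound for 128-bit hash keys

def example_bucket_number(
    hkey: int, num_buckets: int, max_hkey: int = DEFAULT_MAX_HKEY
) -> int:
  # Scale the key proportionally into [0, num_buckets) and clamp the top edge.
  return min(hkey * num_buckets // max_hkey, num_buckets - 1)

assert example_bucket_number(0, NUM_BUCKETS) == 0
assert example_bucket_number(2**128 - 1, NUM_BUCKETS) == NUM_BUCKETS - 1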
@@ -130,25 +129,25 @@ class _Bucket(object):
   ...
   """
 
-  def __init__(self, path):
+  def __init__(self, path: epath.Path):
     """Initialize a _Bucket instance.
 
     Args:
-      path (str): path to bucket file, where to write to or read from.
+      path: Path to bucket file, where to write to or read from.
     """
     self._path = path
     self._fobj = None
     self._length = 0
     self._size = 0
 
   @property
-  def size(self):
+  def size(self) -> int:
     return self._size
 
-  def __len__(self):
+  def __len__(self) -> int:
     return self._length
 
-  def add(self, key, data):
+  def add(self, key: type_utils.Key, data: bytes):
     """Adds (key, data) to bucket.
 
     Args:
@@ -216,18 +215,18 @@ class Shuffler(object):
 
   def __init__(
       self,
-      dirpath,
-      hash_salt,
+      dirpath: epath.PathLike,
+      hash_salt: str | bytes,
       disable_shuffling: bool = False,
       ignore_duplicates: bool = False,
   ):
     """Initialize Shuffler.
 
     Args:
-      dirpath (string): directory in which to store temporary files.
-      hash_salt (string or bytes): salt to hash keys.
-      disable_shuffling (bool): specify whether to shuffle by hashing the key.
-      ignore_duplicates: whether to ignore duplicated examples with the same
+      dirpath: Path to the directory in which to store temporary files.
+      hash_salt: Salt to hash keys.
+      disable_shuffling: Specifies whether to shuffle by hashing the key.
+      ignore_duplicates: Whether to ignore duplicated examples with the same
         key. If there are multiple examples with the same key, the first one is
         kept. If this is False, then a `DuplicatedKeysError` is raised.
     """
@@ -238,7 +237,7 @@ def __init__(
     self._buckets: list[_Bucket] = []
     for i in range(BUCKETS_NUMBER):
       bucket_name = 'bucket_%s_%03d.tmp' % (grp_name, i)
-      path = os.path.join(dirpath, bucket_name)
+      path = epath.Path(dirpath) / bucket_name
       self._buckets.append(_Bucket(path))
     self._read_only = False
     self._total_bytes = 0
@@ -263,25 +262,25 @@ def bucket_lengths(self) -> Sequence[int]:
   def num_examples(self) -> int:
     return self._num_examples
 
-  def _add_to_bucket(self, hkey, data) -> None:
+  def _add_to_bucket(self, hkey: int, data: bytes) -> None:
     bucket_number = get_bucket_number(hkey=hkey, num_buckets=BUCKETS_NUMBER)
     self._buckets[bucket_number].add(hkey, data)
 
-  def _add_to_mem_buffer(self, hkey, data) -> None:
+  def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None:
     self._mem_buffer.append((hkey, data))
     if self._total_bytes > MAX_MEM_BUFFER_SIZE:
       for hkey, data in self._mem_buffer:
         self._add_to_bucket(hkey, data)
       self._mem_buffer = None
       self._in_memory = False
 
-  def add(self, key, data) -> bool:
+  def add(self, key: type_utils.Key, data: bytes) -> bool:
     """Add (key, data) to shuffler."""
     if self._read_only:
       raise AssertionError('add() cannot be called after __iter__.')
-    if not isinstance(data, six.binary_type):
+    if not isinstance(data, bytes):
       raise AssertionError(
-          'Only bytes (not %s) can be stored in Shuffler!' % (type(data))
+          f'Only bytes (not {type(data)}) can be stored in Shuffler!'
       )
     hkey = self._hasher.hash_key(key)
     if self._ignore_duplicates:
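Taken together, the annotated signatures describe the Shuffler's surface as used by writers: construct it with a directory and a hash salt, feed it `(key, bytes)` pairs, then iterate. The following is a hypothetical usage sketch based only on what is visible in this diff; the module path, the example values, and the iteration step are assumptions.

from tensorflow_datasets.core import shuffle  # assumed module path

shuffler = shuffle.Shuffler(
    dirpath='/tmp/shuffler_buckets',   # temporary bucket files are written here
    hash_salt='my-split-salt',         # str or bytes salt used to hash keys
    disable_shuffling=False,
    ignore_duplicates=False,
)
shuffler.add(key=1, data=b'serialized example 1')  # data must be bytes
shuffler.add(key=2, data=b'serialized example 2')

# Iterating reads the examples back; calling add() afterwards raises
# AssertionError('add() cannot be called after __iter__.').
for example in shuffler:
  ...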