@@ -73,11 +73,19 @@ def _raise_error_for_duplicated_keys(example1, example2, example_specs):
   """Log information about the examples and raise an AssertionError."""
   msg = "Two examples share the same hashed key!"
   logging.error(msg)
-  parser = example_parser.ExampleParser(example_specs)
-  ex1 = parser.parse_example(example1)
-  ex2 = parser.parse_example(example2)
-  logging.error("1st example: %s", ex1)
-  logging.error("2nd example: %s", ex2)
+  try:
+    parser = example_parser.ExampleParser(example_specs)
+    ex1 = parser.parse_example(example1)
+    ex2 = parser.parse_example(example2)
+    logging.error("1st example: %s", ex1)
+    logging.error("2nd example: %s", ex2)
+  except ValueError:
+    logging.error(
+        "Failed to parse examples! Cannot log them to see the examples behind"
+        " the duplicated keys. Raw example 1: %s, raw example 2: %s",
+        example1,
+        example2,
+    )
   raise AssertionError(msg + " See logs above to view the examples.")


@@ -192,6 +200,7 @@ def __init__(
       disable_shuffling: bool,
       example_writer: ExampleWriter,
       shard_config: shard_utils.ShardConfig | None = None,
+      ignore_duplicates: bool = False,
   ):
     """Initializes Writer.

@@ -202,14 +211,16 @@ def __init__(
       disable_shuffling (bool): Specifies whether to shuffle the records.
       example_writer: class that writes examples to disk or elsewhere.
       shard_config: the configuration for creating shards.
+      ignore_duplicates: whether to ignore duplicated examples with the same
+        key. If False, a `DuplicatedKeysError` will be raised on duplicates.
     """
     self._serializer = serializer
     self._shuffler = shuffle.Shuffler(
         dirpath=filename_template.data_dir,
         hash_salt=hash_salt,
         disable_shuffling=disable_shuffling,
+        ignore_duplicates=ignore_duplicates,
     )
-    self._num_examples = 0
     self._filename_template = filename_template
     self._shard_config = shard_config or shard_utils.ShardConfig()
     self._example_writer = example_writer
@@ -226,13 +237,12 @@ def write(self, key: int | bytes, example: Example):
     """
     serialized_example = self._serializer.serialize_example(example=example)
     self._shuffler.add(key, serialized_example)
-    self._num_examples += 1

   def finalize(self) -> tuple[list[int], int]:
     """Effectively writes examples to the shards."""
     filename = self._filename_template.sharded_filepaths_pattern()
     shard_specs = _get_shard_specs(
-        num_examples=self._num_examples,
+        num_examples=self._shuffler.num_examples,
         total_size=self._shuffler.size,
         bucket_lengths=self._shuffler.bucket_lengths,
         filename_template=self._filename_template,
@@ -245,7 +255,7 @@ def finalize(self) -> tuple[list[int], int]:
     utils.tqdm(
         self._shuffler,
         desc=f"Shuffling {filename}...",
-        total=self._num_examples,
+        total=self._shuffler.num_examples,
         unit=" examples",
         leave=False,
         mininterval=1.0,
@@ -322,6 +332,7 @@ def __init__(
       disable_shuffling: bool,
       example_writer: ExampleWriter,
       shard_config: shard_utils.ShardConfig | None = None,
+      ignore_duplicates: bool = False,
   ):
     """Init BeamWriter.

@@ -336,6 +347,8 @@ def __init__(
      disable_shuffling: bool, specifies whether to shuffle the records.
      example_writer: class that writes examples to storage.
      shard_config: the configuration for creating shards.
+      ignore_duplicates: whether to ignore duplicated examples with the same
+        key. If False, a `DuplicatedKeysError` will be raised on duplicates.
     """
     self._original_state = dict(
         serializer=serializer,
@@ -344,6 +357,7 @@ def __init__(
         disable_shuffling=disable_shuffling,
         shard_config=shard_config,
         example_writer=example_writer,
+        ignore_duplicates=ignore_duplicates,
     )
     self._filename_template = filename_template
     self._split_info_path = (
@@ -355,6 +369,7 @@ def __init__(
     self._disable_shuffling = disable_shuffling
     self._shard_config = shard_config or shard_utils.ShardConfig()
     self._example_writer = example_writer
+    self._ignore_duplicates = ignore_duplicates

   @functools.lru_cache()
   def _get_counter(self, name: str, namespace: str = "BeamWriter"):
@@ -416,29 +431,34 @@ def _write_final_shard(
       raise AssertionError("Not a single example present in the PCollection!")
     # There may be empty shards, this ensures there are no gaps.
     shard_id = non_empty_shard_ids.index(original_shard_id)
-    examples = sorted(examples)
-    self._get_distribution(name="ShardLenDistribution").update(len(examples))
-    # Compare continuous examples
-    for ex0, ex1 in zip(examples[:-1], examples[1:]):
-      if ex0[0] == ex1[0]:  # Different keys
-        _raise_error_for_duplicated_keys(
-            ex0[1], ex1[1], self._serializer.example_specs
-        )
+    example_by_key = {}
+    for key, example in examples:
+      if key in example_by_key:
+        if not self._ignore_duplicates:
+          _raise_error_for_duplicated_keys(
+              example_by_key[key], example, self._serializer.example_specs
+          )
+      else:
+        example_by_key[key] = example
     shard_path = self._filename_template.sharded_filepath(
         shard_index=shard_id, num_shards=len(non_empty_shard_ids)
     )
     with utils.incomplete_file(epath.Path(shard_path)) as tmp_path:
       logging.info(
-          "Writing %d examples to %s.", len(examples), os.fspath(tmp_path)
+          "Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
+      )
+      record_keys = self._example_writer.write(
+          tmp_path, sorted(example_by_key.items())
       )
-      record_keys = self._example_writer.write(tmp_path, examples)
     self.inc_counter(name="written_shards")
     # If there are record_keys, create index files.
     if record_keys:
       index_path = _get_index_path(os.fspath(shard_path))
       _write_index_file(index_path, list(record_keys))
-    shard_size = sum(map(len, examples))
-    return _ShardInfo(id=shard_id, num_examples=len(examples), size=shard_size)
+    shard_size = sum(map(len, example_by_key.values()))
+    return _ShardInfo(
+        id=shard_id, num_examples=len(example_by_key), size=shard_size
+    )

   def _number_of_shards(self, num_examples: int, total_size: int) -> int:
     """Returns the number of shards."""
@@ -468,11 +488,11 @@ def _assign_shard(
   def _store_split_info(
       self,
       shard_infos: Sequence[_ShardInfo],
-      total_size: int,
   ) -> None:
     """Stores the split info to disk."""
     shard_infos = sorted(shard_infos, key=lambda x: x.id)
     shard_lengths = [info.num_examples for info in shard_infos]
+    total_size = sum([info.size for info in shard_infos])
     with utils.incomplete_file(epath.Path(self._split_info_path)) as tmp_path:
       tmp_path.write_text(
           json.dumps({"total_size": total_size, "shard_lengths": shard_lengths})
@@ -553,8 +573,7 @@ def write_from_pcollection(self, examples_pcollection):
         # (_ShardInfo)
         | "CollectShardInfo" >> beam.transforms.combiners.ToList()
        # [_ShardInfo]
-        | "CalculateSplitInfo"
-        >> beam.ParDo(self._store_split_info, total_size=total_size)
+        | "CalculateSplitInfo" >> beam.ParDo(self._store_split_info)
     )

   def finalize(self) -> tuple[list[int], int]:
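
For context on the core behavioral change in `_write_final_shard` above: the old code sorted the shard's examples and compared adjacent pairs to detect duplicated keys, while the new code deduplicates through a dict keyed by example key and raises only when `self._ignore_duplicates` is False, so the first occurrence of a key wins. A minimal standalone sketch of that logic, assuming `examples` is an iterable of `(key, serialized_example)` pairs (the `dedupe_examples` helper and the plain `AssertionError` are illustrative; the real code calls `_raise_error_for_duplicated_keys`, which logs both examples before raising):

  def dedupe_examples(examples, ignore_duplicates=False):
    """Mirrors the dict-based duplicate scan introduced in the diff."""
    example_by_key = {}
    for key, example in examples:
      if key in example_by_key:
        if not ignore_duplicates:
          # Stand-in for _raise_error_for_duplicated_keys(...).
          raise AssertionError("Two examples share the same hashed key!")
        # With ignore_duplicates=True, later duplicates are dropped.
      else:
        example_by_key[key] = example
    # The shard is then written in key order, matching
    # self._example_writer.write(tmp_path, sorted(example_by_key.items())).
    return sorted(example_by_key.items())

  # Duplicates are silently dropped when ignore_duplicates=True:
  pairs = [(2, b"b"), (1, b"a"), (2, b"dup")]
  assert dedupe_examples(pairs, ignore_duplicates=True) == [(1, b"a"), (2, b"b")]

Note also that `_store_split_info` now derives `total_size` by summing `info.size` over the collected `_ShardInfo`s rather than receiving it as a pipeline side input, so the recorded size reflects the deduplicated examples actually written.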