@@ -592,24 +592,29 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):

     with tempfile.TemporaryDirectory() as tmp_dir:
       tmp_dir = epath.Path(tmp_dir)
-      for split in ('train-b', 'train'):
+
+      def get_writer(split):
         filename_template = naming.ShardedFileTemplate(
             dataset_name='foo',
             split=split,
             filetype_suffix=file_format.file_suffix,
             data_dir=tmp_dir,
         )
-        writer = writer_lib.NoShuffleBeamWriter(
+        return writer_lib.NoShuffleBeamWriter(
             serializer=testing.DummySerializer('dummy specs'),
             filename_template=filename_template,
             file_format=file_format,
         )
-        to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
-        # Here we need to disable type check as `beam.Create` is not capable
-        # of inferring the type of the PCollection elements.
-        options = beam.options.pipeline_options.PipelineOptions(
-            pipeline_type_check=False
-        )
+
+      to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
+      # Here we need to disable type check as `beam.Create` is not capable of
+      # inferring the type of the PCollection elements.
+      options = beam.options.pipeline_options.PipelineOptions(
+          pipeline_type_check=False
+      )
+      writers = [get_writer(split) for split in ('train-b', 'train')]
+
+      for writer in writers:
         with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:

           @beam.ptransform_fn
@@ -618,14 +623,16 @@ def _build_pcollection(pipeline):
             return writer.write_from_pcollection(pcollection)

           _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
+
+      files = list(tmp_dir.iterdir())
+      self.assertGreaterEqual(len(files), 2)
+      for f in files:
+        self.assertIn(file_format.file_suffix, f.name)
+      for writer in writers:
         shard_lengths, total_size = writer.finalize()
         self.assertNotEmpty(shard_lengths)
         self.assertEqual(sum(shard_lengths), 10)
         self.assertGreater(total_size, 10)
-      files = list(tmp_dir.iterdir())
-      self.assertGreaterEqual(len(files), 1)
-      for f in files:
-        self.assertIn(file_format.file_suffix, f.name)


 class CustomExampleWriter(writer_lib.ExampleWriter):
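Note: the reworked test builds all writers via get_writer() first, runs each one in its own Beam pipeline, and only afterwards checks the output files and calls finalize() on every writer. A minimal standalone sketch of the beam.Create / pipeline_type_check pattern the test relies on (an illustration, not code from this commit; it assumes only apache_beam is installed, and the 'inspect' step is a hypothetical stand-in for write_from_pcollection):

    import apache_beam as beam

    # Same element shape as the test: (key, serialized-bytes) tuples.
    to_write = [(i, str(i).encode('utf-8')) for i in range(10)]

    # `beam.Create` cannot infer the element type of the resulting
    # PCollection, so runtime type checking is disabled, as in the test.
    options = beam.options.pipeline_options.PipelineOptions(
        pipeline_type_check=False
    )

    with beam.Pipeline(options=options) as pipeline:  # default DirectRunner
      _ = (
          pipeline
          | 'start' >> beam.Create(to_write)
          | 'inspect' >> beam.Map(print)  # stand-in for the writer step
      )

Sharing one PipelineOptions object across several short-lived pipelines, as the test now does, keeps the type-check setting consistent for every writer.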