@@ -592,39 +592,47 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):
    with tempfile.TemporaryDirectory() as tmp_dir:
      tmp_dir = epath.Path(tmp_dir)
-      filename_template = naming.ShardedFileTemplate(
-          dataset_name='foo',
-          split='train',
-          filetype_suffix=file_format.file_suffix,
-          data_dir=tmp_dir,
-      )
-      writer = writer_lib.NoShuffleBeamWriter(
-          serializer=testing.DummySerializer('dummy specs'),
-          filename_template=filename_template,
-          file_format=file_format,
-      )
+
+      def get_writer(split):
+        filename_template = naming.ShardedFileTemplate(
+            dataset_name='foo',
+            split=split,
+            filetype_suffix=file_format.file_suffix,
+            data_dir=tmp_dir,
+        )
+        return writer_lib.NoShuffleBeamWriter(
+            serializer=testing.DummySerializer('dummy specs'),
+            filename_template=filename_template,
+            file_format=file_format,
+        )
+
      to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
      # Here we need to disable type check as `beam.Create` is not capable of
      # inferring the type of the PCollection elements.
      options = beam.options.pipeline_options.PipelineOptions(
          pipeline_type_check=False
      )
-      with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
-
-        @beam.ptransform_fn
-        def _build_pcollection(pipeline):
-          pcollection = pipeline | 'Start' >> beam.Create(to_write)
-          return writer.write_from_pcollection(pcollection)
-
-        _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
-        shard_lengths, total_size = writer.finalize()
-        self.assertNotEmpty(shard_lengths)
-        self.assertEqual(sum(shard_lengths), 10)
-        self.assertGreater(total_size, 10)
+      writers = [get_writer(split) for split in ('train-b', 'train')]
+
+      for writer in writers:
+        with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
+
+          @beam.ptransform_fn
+          def _build_pcollection(pipeline, writer):
+            pcollection = pipeline | 'Start' >> beam.Create(to_write)
+            return writer.write_from_pcollection(pcollection)
+
+          _ = pipeline | 'test' >> _build_pcollection(writer)
+
      files = list(tmp_dir.iterdir())
-      self.assertGreaterEqual(len(files), 1)
+      self.assertGreaterEqual(len(files), 2)
      for f in files:
        self.assertIn(file_format.file_suffix, f.name)
+      for writer in writers:
+        shard_lengths, total_size = writer.finalize()
+        self.assertNotEmpty(shard_lengths)
+        self.assertEqual(sum(shard_lengths), 10)
+        self.assertGreater(total_size, 10)


class CustomExampleWriter(writer_lib.ExampleWriter):
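Note on the `_build_pcollection(writer)` call in the updated hunk: `beam.ptransform_fn` binds the decorated function's first parameter to whatever PCollection is piped in via `|`, while any remaining parameters are supplied at construction time. That is why the test can pass one writer per loop iteration. A minimal, self-contained sketch of this mechanism (the `_append_suffix` transform, its labels, and the sample data are hypothetical, and it assumes Beam's default DirectRunner):

```python
import apache_beam as beam


@beam.ptransform_fn
def _append_suffix(pcollection, suffix):
  # `pcollection` is bound by the `|` operator when the transform is applied;
  # `suffix` is the ordinary argument passed when it is constructed.
  return pcollection | 'Append' >> beam.Map(lambda x: x + suffix)


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | 'Create' >> beam.Create(['a', 'b'])
      # Only `suffix` is given here; the PCollection arrives via `|`.
      | 'WithSuffix' >> _append_suffix('-train')  # pylint: disable=no-value-for-parameter
      | 'Print' >> beam.Map(print)
  )
```

This mirrors the refactored test: each writer is threaded in as a plain argument, and only the pipeline flows through the `|` chain.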