@@ -592,39 +592,40 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):
592
592
593
593
with tempfile .TemporaryDirectory () as tmp_dir :
594
594
tmp_dir = epath .Path (tmp_dir )
595
- filename_template = naming .ShardedFileTemplate (
596
- dataset_name = 'foo' ,
597
- split = 'train' ,
598
- filetype_suffix = file_format .file_suffix ,
599
- data_dir = tmp_dir ,
600
- )
601
- writer = writer_lib .NoShuffleBeamWriter (
602
- serializer = testing .DummySerializer ('dummy specs' ),
603
- filename_template = filename_template ,
604
- file_format = file_format ,
605
- )
606
- to_write = [(i , str (i ).encode ('utf-8' )) for i in range (10 )]
607
- # Here we need to disable type check as `beam.Create` is not capable of
608
- # inferring the type of the PCollection elements.
609
- options = beam .options .pipeline_options .PipelineOptions (
610
- pipeline_type_check = False
611
- )
612
- with beam .Pipeline (options = options , runner = _get_runner ()) as pipeline :
613
-
614
- @beam .ptransform_fn
615
- def _build_pcollection (pipeline ):
616
- pcollection = pipeline | 'Start' >> beam .Create (to_write )
617
- return writer .write_from_pcollection (pcollection )
618
-
619
- _ = pipeline | 'test' >> _build_pcollection () # pylint: disable=no-value-for-parameter
620
- shard_lengths , total_size = writer .finalize ()
621
- self .assertNotEmpty (shard_lengths )
622
- self .assertEqual (sum (shard_lengths ), 10 )
623
- self .assertGreater (total_size , 10 )
624
- files = list (tmp_dir .iterdir ())
625
- self .assertGreaterEqual (len (files ), 1 )
626
- for f in files :
627
- self .assertIn (file_format .file_suffix , f .name )
595
+ for split in ('train-b' , 'train' ):
596
+ filename_template = naming .ShardedFileTemplate (
597
+ dataset_name = 'foo' ,
598
+ split = split ,
599
+ filetype_suffix = file_format .file_suffix ,
600
+ data_dir = tmp_dir ,
601
+ )
602
+ writer = writer_lib .NoShuffleBeamWriter (
603
+ serializer = testing .DummySerializer ('dummy specs' ),
604
+ filename_template = filename_template ,
605
+ file_format = file_format ,
606
+ )
607
+ to_write = [(i , str (i ).encode ('utf-8' )) for i in range (10 )]
608
+ # Here we need to disable type check as `beam.Create` is not capable
609
+ # of inferring the type of the PCollection elements.
610
+ options = beam .options .pipeline_options .PipelineOptions (
611
+ pipeline_type_check = False
612
+ )
613
+ with beam .Pipeline (options = options , runner = _get_runner ()) as pipeline :
614
+
615
+ @beam .ptransform_fn
616
+ def _build_pcollection (pipeline ):
617
+ pcollection = pipeline | 'Start' >> beam .Create (to_write )
618
+ return writer .write_from_pcollection (pcollection )
619
+
620
+ _ = pipeline | 'test' >> _build_pcollection () # pylint: disable=no-value-for-parameter
621
+ shard_lengths , total_size = writer .finalize ()
622
+ self .assertNotEmpty (shard_lengths )
623
+ self .assertEqual (sum (shard_lengths ), 10 )
624
+ self .assertGreater (total_size , 10 )
625
+ files = list (tmp_dir .iterdir ())
626
+ self .assertGreaterEqual (len (files ), 1 )
627
+ for f in files :
628
+ self .assertIn (file_format .file_suffix , f .name )
628
629
629
630
630
631
class CustomExampleWriter (writer_lib .ExampleWriter ):
0 commit comments