Skip to content

Commit 5de1d20

Browse files
committed
Run both writers before calling testing finalize output
1 parent 6f96eaa commit 5de1d20

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

tensorflow_datasets/core/writer_test.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -592,24 +592,29 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):
592592

593593
with tempfile.TemporaryDirectory() as tmp_dir:
594594
tmp_dir = epath.Path(tmp_dir)
595-
for split in ('train-b', 'train'):
595+
596+
def get_writer(split):
596597
filename_template = naming.ShardedFileTemplate(
597598
dataset_name='foo',
598599
split=split,
599600
filetype_suffix=file_format.file_suffix,
600601
data_dir=tmp_dir,
601602
)
602-
writer = writer_lib.NoShuffleBeamWriter(
603+
return writer_lib.NoShuffleBeamWriter(
603604
serializer=testing.DummySerializer('dummy specs'),
604605
filename_template=filename_template,
605606
file_format=file_format,
606607
)
607-
to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
608-
# Here we need to disable type check as `beam.Create` is not capable
609-
# of inferring the type of the PCollection elements.
610-
options = beam.options.pipeline_options.PipelineOptions(
611-
pipeline_type_check=False
612-
)
608+
609+
to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
610+
# Here we need to disable type check as `beam.Create` is not capable of
611+
# inferring the type of the PCollection elements.
612+
options = beam.options.pipeline_options.PipelineOptions(
613+
pipeline_type_check=False
614+
)
615+
writers = [get_writer(split) for split in ('train-b', 'train')]
616+
617+
for writer in writers:
613618
with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
614619

615620
@beam.ptransform_fn
@@ -618,14 +623,16 @@ def _build_pcollection(pipeline):
618623
return writer.write_from_pcollection(pcollection)
619624

620625
_ = pipeline | 'test' >> _build_pcollection() # pylint: disable=no-value-for-parameter
626+
627+
files = list(tmp_dir.iterdir())
628+
self.assertGreaterEqual(len(files), 2)
629+
for f in files:
630+
self.assertIn(file_format.file_suffix, f.name)
631+
for writer in writers:
621632
shard_lengths, total_size = writer.finalize()
622633
self.assertNotEmpty(shard_lengths)
623634
self.assertEqual(sum(shard_lengths), 10)
624635
self.assertGreater(total_size, 10)
625-
files = list(tmp_dir.iterdir())
626-
self.assertGreaterEqual(len(files), 1)
627-
for f in files:
628-
self.assertIn(file_format.file_suffix, f.name)
629636

630637

631638
class CustomExampleWriter(writer_lib.ExampleWriter):

0 commit comments

Comments
 (0)