Commit 6f96eaa

Add test to verify correct shard computation with overlapping splits
1 parent 132b530 commit 6f96eaa
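
The two split names used in the new test overlap: 'train' is a prefix of 'train-b', so shard files written for 'train-b' can be picked up by a filename pattern that was only meant to match the 'train' shards, inflating the shard lengths computed at finalization. Below is a minimal sketch of that failure mode, assuming TFDS-style shard names of the form {name}-{split}.{suffix}-NNNNN-of-NNNNN; the template and the glob patterns are illustrative assumptions, not the writer's actual matching logic.

import fnmatch

# Hypothetical shard files for both splits, following the assumed
# '{name}-{split}.{suffix}-NNNNN-of-NNNNN' naming scheme.
files = [
    'foo-train-b.tfrecord-00000-of-00001',
    'foo-train.tfrecord-00000-of-00001',
]

# A loose pattern for the 'train' split also catches the 'train-b'
# shard, which would double the example count computed for 'train':
assert fnmatch.filter(files, 'foo-train*') == files

# Anchoring the pattern on the filetype suffix keeps the splits apart:
assert fnmatch.filter(files, 'foo-train.tfrecord*') == [
    'foo-train.tfrecord-00000-of-00001'
]

This is also why the loop in the diff writes 'train-b' before 'train': the overlapping shards are already on disk when the 'train' writer is finalized.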

File tree

1 file changed (+34, -33)


tensorflow_datasets/core/writer_test.py

Lines changed: 34 additions & 33 deletions
@@ -592,39 +592,40 @@ def test_write_beam(self, file_format: file_adapters.FileFormat):
 
     with tempfile.TemporaryDirectory() as tmp_dir:
       tmp_dir = epath.Path(tmp_dir)
-      filename_template = naming.ShardedFileTemplate(
-          dataset_name='foo',
-          split='train',
-          filetype_suffix=file_format.file_suffix,
-          data_dir=tmp_dir,
-      )
-      writer = writer_lib.NoShuffleBeamWriter(
-          serializer=testing.DummySerializer('dummy specs'),
-          filename_template=filename_template,
-          file_format=file_format,
-      )
-      to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
-      # Here we need to disable type check as `beam.Create` is not capable of
-      # inferring the type of the PCollection elements.
-      options = beam.options.pipeline_options.PipelineOptions(
-          pipeline_type_check=False
-      )
-      with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
-
-        @beam.ptransform_fn
-        def _build_pcollection(pipeline):
-          pcollection = pipeline | 'Start' >> beam.Create(to_write)
-          return writer.write_from_pcollection(pcollection)
-
-        _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
-      shard_lengths, total_size = writer.finalize()
-      self.assertNotEmpty(shard_lengths)
-      self.assertEqual(sum(shard_lengths), 10)
-      self.assertGreater(total_size, 10)
-      files = list(tmp_dir.iterdir())
-      self.assertGreaterEqual(len(files), 1)
-      for f in files:
-        self.assertIn(file_format.file_suffix, f.name)
+      for split in ('train-b', 'train'):
+        filename_template = naming.ShardedFileTemplate(
+            dataset_name='foo',
+            split=split,
+            filetype_suffix=file_format.file_suffix,
+            data_dir=tmp_dir,
+        )
+        writer = writer_lib.NoShuffleBeamWriter(
+            serializer=testing.DummySerializer('dummy specs'),
+            filename_template=filename_template,
+            file_format=file_format,
+        )
+        to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
+        # Here we need to disable type check as `beam.Create` is not capable
+        # of inferring the type of the PCollection elements.
+        options = beam.options.pipeline_options.PipelineOptions(
+            pipeline_type_check=False
+        )
+        with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
+
+          @beam.ptransform_fn
+          def _build_pcollection(pipeline):
+            pcollection = pipeline | 'Start' >> beam.Create(to_write)
+            return writer.write_from_pcollection(pcollection)
+
+          _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
+        shard_lengths, total_size = writer.finalize()
+        self.assertNotEmpty(shard_lengths)
+        self.assertEqual(sum(shard_lengths), 10)
+        self.assertGreater(total_size, 10)
+        files = list(tmp_dir.iterdir())
+        self.assertGreaterEqual(len(files), 1)
+        for f in files:
+          self.assertIn(file_format.file_suffix, f.name)
 
 
 class CustomExampleWriter(writer_lib.ExampleWriter):
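
The iteration order in the loop is what makes the test meaningful: 'train-b' is written and finalized first, so its shard already sits in the shared tmp_dir when the 'train' writer runs, and assertEqual(sum(shard_lengths), 10) would see 20 if the writer counted it too. A rough simulation of the state the second iteration sees, under the same assumed filename template as above (prefix matching on '{name}-{split}.' stands in for whatever the writer actually does):

from pathlib import Path
import tempfile

with tempfile.TemporaryDirectory() as d:
  tmp_dir = Path(d)
  # The 'train-b' shard from the first iteration is already present
  # when the second iteration writes the 'train' shard.
  (tmp_dir / 'foo-train-b.tfrecord-00000-of-00001').write_bytes(b'x' * 10)
  (tmp_dir / 'foo-train.tfrecord-00000-of-00001').write_bytes(b'x' * 10)

  # Shard discovery for 'train' must skip the 'train-b' file.
  train_shards = [
      f for f in tmp_dir.iterdir() if f.name.startswith('foo-train.')
  ]
  assert len(train_shards) == 1  # not 2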
class CustomExampleWriter(writer_lib.ExampleWriter):
