
Commit 3e5515f

Author: The TensorFlow Datasets Authors
Add a beam writer that doesn't shuffle
PiperOrigin-RevId: 693459753
1 parent 8ab25ad commit 3e5515f

File tree: 2 files changed, +144 −4 lines

tensorflow_datasets/core/writer.py

Lines changed: 94 additions & 0 deletions
@@ -717,3 +717,97 @@ def finalize(self) -> tuple[list[int], int]:
       split_info_path.unlink()
 
     return self._split_info["shard_lengths"], self._split_info["total_size"]
+
+
+class NoShuffleBeamWriter:
+  """Writes examples from a Beam collection to sharded files, without shuffling."""
+
+  _OUTPUT_TAG_BUCKETS_LEN_SIZE = "tag_buckets_len_size"
+
+  def __init__(
+      self,
+      serializer: example_serializer.Serializer,
+      filename_template: naming.ShardedFileTemplate,
+      file_format: file_adapters.FileFormat,
+  ):
+    """Init NoShuffleBeamWriter.
+
+    Args:
+      serializer: class that can serialize examples.
+      filename_template: template to format sharded filenames.
+      file_format: the file format to use.
+    """
+    self._original_state = dict(
+        serializer=serializer,
+        filename_template=filename_template,
+        file_format=file_format,
+    )
+    self._file_format = file_format
+    self._file_adapter = file_adapters.ADAPTER_FOR_FORMAT[self._file_format]
+    self._filename_template = filename_template
+    self._serializer = serializer
+
+  @functools.lru_cache()
+  def _get_counter(self, name: str, namespace: str = "BeamWriter"):
+    return beam.metrics.Metrics.counter(namespace, name)
+
+  def inc_counter(self, name: str, value: int = 1) -> None:
+    self._get_counter(name).inc(value)
+
+  def __getstate__(self):
+    return self._original_state
+
+  def __setstate__(self, state):
+    self.__init__(**state)
+
+  def _serialize_example(
+      self,
+      key_example: tuple[hashing.HashKey, Example],
+  ) -> bytes:
+    """Returns the serialized example; the key is dropped."""
+    _, example = key_example
+    self.inc_counter(name="serialized_examples")
+    return self._serializer.serialize_example(example)
+
+  def write_from_pcollection(self, examples_pcollection):
+    """Writes the (key, example) PCollection to sharded files."""
+    return (
+        examples_pcollection
+        | "Serialize" >> beam.Map(self._serialize_example)
+        | "Write"
+        >> self._file_adapter.beam_sink(
+            filename_template=self._filename_template
+        )
+    )
+
+  def finalize(self) -> tuple[list[int], int]:
+    """Returns the computed shard_lengths and total_size.
+
+    Returns:
+      A list of length <number of shards> containing the number of examples
+      stored in each shard, and the total size of the files (in bytes).
+    """
+    logging.info("Finalizing writer for %s", self._filename_template.split)
+    # We don't know the number of shards, the length of each shard, nor the
+    # total size, so we compute them here.
+    length_per_shard = {}
+    total_size_bytes = 0
+    prefix = epath.Path(self._filename_template.filepath_prefix())
+    for shard in self._filename_template.data_dir.glob(f"{prefix.name}*"):
+      length = self._file_adapter.num_examples(shard)
+      length_per_shard[shard] = length
+      total_size_bytes += shard.stat().length
+    shard_lengths: list[int] = []
+    for _, length in sorted(length_per_shard.items()):
+      shard_lengths.append(length)
+    logging.info(
+        "Found %d shards with a total size of %d bytes.",
+        len(shard_lengths),
+        total_size_bytes,
+    )
+
+    return shard_lengths, total_size_bytes
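
For orientation, here is a minimal end-to-end sketch of driving the new writer. It mirrors the test added below; the default Beam runner, the DummySerializer, and the /tmp/foo output directory are illustrative assumptions, not part of this commit.

    # Sketch only: mirrors NoShuffleBeamWriterTest below. The runner,
    # serializer, and output directory are illustrative assumptions.
    import apache_beam as beam
    from etils import epath
    from tensorflow_datasets import testing
    from tensorflow_datasets.core import file_adapters
    from tensorflow_datasets.core import naming
    from tensorflow_datasets.core import writer as writer_lib

    file_format = file_adapters.FileFormat.TFRECORD
    filename_template = naming.ShardedFileTemplate(
        dataset_name='foo',
        split='train',
        filetype_suffix=file_format.file_suffix,
        data_dir=epath.Path('/tmp/foo'),  # Illustrative output directory.
    )
    writer = writer_lib.NoShuffleBeamWriter(
        serializer=testing.DummySerializer('dummy specs'),
        filename_template=filename_template,
        file_format=file_format,
    )
    # `beam.Create` cannot infer element types, so type checking is
    # disabled (as in the test below).
    options = beam.options.pipeline_options.PipelineOptions(
        pipeline_type_check=False
    )
    with beam.Pipeline(options=options) as pipeline:
      pcollection = pipeline | 'Start' >> beam.Create(
          [(i, str(i).encode('utf-8')) for i in range(10)]
      )
      writer.write_from_pcollection(pcollection)
    # finalize() globs the written shards to compute their lengths and size.
    shard_lengths, total_size = writer.finalize()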

tensorflow_datasets/core/writer_test.py

Lines changed: 50 additions & 4 deletions
@@ -13,21 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tests for tensorflow_datasets.core.writer."""
-
 import json
 import os
+import tempfile
 from typing import Optional
 from unittest import mock
 
 from absl.testing import parameterized
+import apache_beam as beam
 from etils import epath
 import tensorflow as tf
 from tensorflow_datasets import testing
 from tensorflow_datasets.core import dataset_utils
 from tensorflow_datasets.core import example_parser
 from tensorflow_datasets.core import file_adapters
-from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import writer as writer_lib
 from tensorflow_datasets.core.utils import shard_utils

@@ -409,6 +408,10 @@ def test_too_small_split(self):
     self._write(to_write=to_write)
 
 
+def _get_runner() -> beam.runners.PipelineRunner:
+  return beam.runners.DirectRunner()
+
+
 class TfrecordsWriterBeamTest(testing.TestCase):
   NUM_SHARDS = 3
   RECORDS_TO_WRITE = [(i, str(i).encode('utf-8')) for i in range(10)]

@@ -455,7 +458,6 @@ def _write(
     shard_config = shard_config or shard_utils.ShardConfig(
         num_shards=self.NUM_SHARDS
     )
-    beam = lazy_imports_lib.lazy_imports.apache_beam
     writer = writer_lib.BeamWriter(
         serializer=testing.DummySerializer('dummy specs'),
         filename_template=filename_template,

@@ -581,6 +583,50 @@ def test_write_tfrecord_sorted_by_key_with_holes(self):
     self.assertEmpty(all_indices)
 
 
+class NoShuffleBeamWriterTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ('tfrecord', file_adapters.FileFormat.TFRECORD),
+  )
+  def test_write_beam(self, file_format: file_adapters.FileFormat):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+      tmp_dir = epath.Path(tmp_dir)
+      filename_template = naming.ShardedFileTemplate(
+          dataset_name='foo',
+          split='train',
+          filetype_suffix=file_format.file_suffix,
+          data_dir=tmp_dir,
+      )
+      writer = writer_lib.NoShuffleBeamWriter(
+          serializer=testing.DummySerializer('dummy specs'),
+          filename_template=filename_template,
+          file_format=file_format,
+      )
+      to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
+      # We need to disable type checking, as `beam.Create` is not capable of
+      # inferring the type of the PCollection elements.
+      options = beam.options.pipeline_options.PipelineOptions(
+          pipeline_type_check=False
+      )
+      with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
+
+        @beam.ptransform_fn
+        def _build_pcollection(pipeline):
+          pcollection = pipeline | 'Start' >> beam.Create(to_write)
+          return writer.write_from_pcollection(pcollection)
+
+        _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
+      shard_lengths, total_size = writer.finalize()
+      self.assertNotEmpty(shard_lengths)
+      self.assertEqual(sum(shard_lengths), 10)
+      self.assertGreater(total_size, 10)
+      files = list(tmp_dir.iterdir())
+      self.assertGreaterEqual(len(files), 1)
+      for f in files:
+        self.assertIn(file_format.file_suffix, f.name)
+
+
 class CustomExampleWriter(writer_lib.ExampleWriter):
 
   def __init__(self):
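
As a quick sanity check outside the assertions above, the written shards can be read back and counted with TensorFlow. This is a sketch assuming the TFRECORD format and the illustrative /tmp/foo directory from the earlier example; the glob pattern assumes the default ShardedFileTemplate naming.

    import tensorflow as tf

    # Count records across all shards produced by the sketch above.
    files = tf.io.gfile.glob('/tmp/foo/foo-train.tfrecord*')
    num_records = sum(1 for _ in tf.data.TFRecordDataset(files))
    assert num_records == 10  # One record per example written.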
