Skip to content

Commit cfa03b7

Browse files
author
The TensorFlow Datasets Authors
committed
Include a beam_sink class method to the file adapters.
PiperOrigin-RevId: 693445053
1 parent 882d2e3 commit cfa03b7

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

tensorflow_datasets/core/file_adapters.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@
2626
from typing import Any, ClassVar, Type, TypeVar
2727

2828
from etils import epy
29+
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
2930
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
3031
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
3132
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
3233
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3334

35+
3436
with epy.lazy_imports():
3537
# pylint: disable=g-import-not-at-top
3638
from etils import epath
39+
from tensorflow_datasets.core import naming
3740
from tensorflow_datasets.core.utils import file_utils
3841
from tensorflow_datasets.core.utils import type_utils
3942

@@ -167,6 +170,23 @@ def deserialize(cls, raw_example: bytes) -> Any:
167170
"""
168171
return tf.train.Example.FromString(raw_example)
169172

173+
@classmethod
174+
def beam_sink(
175+
cls,
176+
filename_template: naming.ShardedFileTemplate,
177+
num_shards: int | None = None,
178+
) -> beam.PTransform:
179+
"""Returns a Beam sink for writing examples in the given file format."""
180+
raise NotImplementedError()
181+
182+
@classmethod
183+
def num_examples(cls, filename: epath.PathLike) -> int:
184+
"""Returns the number of examples in the given file."""
185+
n = 0
186+
for _ in cls.make_tf_data(filename):
187+
n += 1
188+
return n
189+
170190

171191
class TfRecordFileAdapter(FileAdapter):
172192
"""File adapter for TFRecord file format."""
@@ -205,6 +225,20 @@ def write_examples(
205225
writer.write(serialized_example)
206226
writer.flush()
207227

228+
@classmethod
229+
def beam_sink(
230+
cls,
231+
filename_template: naming.ShardedFileTemplate,
232+
num_shards: int | None = None,
233+
) -> beam.PTransform:
234+
"""Returns a Beam sink for writing examples in the given file format."""
235+
file_path_prefix = filename_template.sharded_filepaths_pattern(
236+
num_shards=num_shards, use_at_notation=True
237+
).removesuffix('@*')
238+
return beam.io.WriteToTFRecord(
239+
file_path_prefix=file_path_prefix, num_shards=num_shards
240+
)
241+
208242

209243
class RiegeliFileAdapter(FileAdapter):
210244
"""File adapter for Riegeli file format."""

tensorflow_datasets/core/naming.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,7 @@ def sharded_filepaths_pattern(
657657
self,
658658
*,
659659
num_shards: int | None = None,
660+
use_at_notation: bool = False,
660661
) -> str:
661662
"""Returns a pattern describing all the file paths captured by this template.
662663
@@ -668,13 +669,16 @@ def sharded_filepaths_pattern(
668669
669670
Args:
670671
num_shards: optional specification of the number of shards.
672+
use_at_notation: whether to return @* in case `num_shards` is `None`.
671673
672674
Returns:
673675
the pattern describing all shards captured by this template.
674676
"""
675677
a_filepath = self.sharded_filepath(shard_index=0, num_shards=1)
676678
if num_shards:
677679
replacement = f'@{num_shards}'
680+
elif use_at_notation:
681+
replacement = '@*'
678682
else:
679683
replacement = '*'
680684
return _replace_shard_pattern(os.fspath(a_filepath), replacement)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# coding=utf-8
2+
# Copyright 2024 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Lazy import utils untyped. Please, use lazy_imports_utils.py instead.
17+
"""

0 commit comments

Comments
 (0)