|
26 | 26 | from typing import Any, ClassVar, Type, TypeVar
|
27 | 27 |
|
28 | 28 | from etils import epy
|
| 29 | +from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam |
29 | 30 | from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
|
30 | 31 | from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
|
31 | 32 | from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
|
32 | 33 | from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
|
33 | 34 |
|
| 35 | + |
34 | 36 | with epy.lazy_imports():
|
35 | 37 | # pylint: disable=g-import-not-at-top
|
36 | 38 | from etils import epath
|
| 39 | + from tensorflow_datasets.core import naming |
37 | 40 | from tensorflow_datasets.core.utils import file_utils
|
38 | 41 | from tensorflow_datasets.core.utils import type_utils
|
39 | 42 |
|
@@ -167,6 +170,23 @@ def deserialize(cls, raw_example: bytes) -> Any:
|
167 | 170 | """
|
168 | 171 | return tf.train.Example.FromString(raw_example)
|
169 | 172 |
|
@classmethod
def beam_sink(
    cls,
    filename_template: naming.ShardedFileTemplate,
    num_shards: int | None = None,
) -> beam.PTransform:
  """Returns a Beam sink for writing examples in the given file format.

  This implementation never produces a sink: it always raises. Adapters with
  Beam support, such as `TfRecordFileAdapter`, define their own `beam_sink`.

  Args:
    filename_template: Template of the sharded output files (unused here).
    num_shards: Requested number of shards, if any (unused here).

  Returns:
    Nothing; this implementation always raises.

  Raises:
    NotImplementedError: Always. The message names the class the method was
      called on so the unsupported file format is easy to identify.
  """
  # Name the concrete class in the error: a bare NotImplementedError gives no
  # hint about which adapter is missing Beam support.
  raise NotImplementedError(
      f'{cls.__name__} does not implement `beam_sink`.'
  )
| 181 | + |
@classmethod
def num_examples(cls, filename: epath.PathLike) -> int:
  """Counts the examples stored in the given file.

  The count is obtained by iterating over everything that
  `cls.make_tf_data(filename)` yields, so the whole file is traversed and the
  cost grows with the number of examples.

  Args:
    filename: Path of the file whose examples are counted.

  Returns:
    The number of items yielded by `cls.make_tf_data(filename)`.
  """
  # Only the number of items matters, so each yielded example is discarded.
  return sum(1 for _ in cls.make_tf_data(filename))
| 189 | + |
170 | 190 |
|
171 | 191 | class TfRecordFileAdapter(FileAdapter):
|
172 | 192 | """File adapter for TFRecord file format."""
|
@@ -205,6 +225,20 @@ def write_examples(
|
205 | 225 | writer.write(serialized_example)
|
206 | 226 | writer.flush()
|
207 | 227 |
|
@classmethod
def beam_sink(
    cls,
    filename_template: naming.ShardedFileTemplate,
    num_shards: int | None = None,
) -> beam.PTransform:
  """Builds the Beam transform that writes examples to TFRecord files.

  Args:
    filename_template: Template used to derive the output path prefix via its
      `sharded_filepaths_pattern` method.
    num_shards: Number of shards, forwarded both to the template and to
      `beam.io.WriteToTFRecord`. `None` is passed through unchanged.

  Returns:
    A `beam.io.WriteToTFRecord` transform configured with the derived prefix.
  """
  pattern = filename_template.sharded_filepaths_pattern(
      num_shards=num_shards, use_at_notation=True
  )
  # `str.removesuffix` strips the marker only when the pattern ends with the
  # literal '@*'; any other ending is left untouched.
  # NOTE(review): confirm what `sharded_filepaths_pattern` ends with when
  # `num_shards` is set. If it is not '@*', the shard marker stays in the
  # prefix handed to Beam.
  prefix = pattern.removesuffix('@*')
  return beam.io.WriteToTFRecord(
      file_path_prefix=prefix, num_shards=num_shards
  )
| 241 | + |
208 | 242 |
|
209 | 243 | class RiegeliFileAdapter(FileAdapter):
|
210 | 244 | """File adapter for Riegeli file format."""
|
|
0 commit comments