Skip to content

Commit fa4eda5

Browse files
tomvdw and The TensorFlow Datasets Authors
authored and committed
Add file spec method to SplitInfo
Returns the file spec, which is the full path with sharded notation, e.g., `/data/ds/cfg/1.2.3/ds-train@10`. PiperOrigin-RevId: 664793111
1 parent c815f93 commit fa4eda5

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

tensorflow_datasets/core/splits.py

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131

3232
from absl import logging
3333
from etils import epath
34+
from tensorflow_datasets.core import file_adapters
3435
from tensorflow_datasets.core import naming
3536
from tensorflow_datasets.core import proto as proto_lib
3637
from tensorflow_datasets.core import units
@@ -124,7 +125,7 @@ def from_proto(
124125
cls,
125126
proto: proto_lib.SplitInfo,
126127
filename_template: naming.ShardedFileTemplate,
127-
) -> 'SplitInfo':
128+
) -> SplitInfo:
128129
"""Returns a SplitInfo class instance from a SplitInfo proto."""
129130
return cls(
130131
name=proto.name,
@@ -217,10 +218,32 @@ def filepaths(self) -> list[epath.Path]:
217218
self.filename_template.sharded_filepaths(len(self.shard_lengths))
218219
)
219220

220-
def replace(self, **kwargs: Any) -> 'SplitInfo':
221+
def replace(self, **kwargs: Any) -> SplitInfo:
  """Creates a new `SplitInfo` in which the given fields are overridden.

  Args:
    **kwargs: field names mapped to the values they should take in the copy.

  Returns:
    The object produced by `dataclasses.replace` for `self` and `kwargs`.
  """
  updated = dataclasses.replace(self, **kwargs)
  return updated
223224

225+
def file_spec(self, file_format: file_adapters.FileFormat) -> str:
  """Returns the file spec of the split for the given file format.

  A file spec is the full path with sharded notation, e.g.,
  `/data/ds/cfg/1.2.3/ds-train@10`.

  Args:
    file_format: the file format for which to create the file spec.

  Returns:
    The sharded file pattern that the split's filename template produces for
    `self.num_shards` shards.

  Raises:
    ValueError: if the split has no filename template, or if the file suffix
      of `file_format` differs from the suffix of the filename template.
  """
  filename_template = self.filename_template
  # Guard clause: without a template there is nothing to build the spec from.
  if not filename_template:
    raise ValueError(
        f'Could not get filename template for split from split info: {self}.'
    )
  # The requested format must agree with the suffix the template was built
  # with; otherwise the returned spec would point at files of another format.
  if filename_template.filetype_suffix != file_format.file_suffix:
    raise ValueError(
        f'File format {file_format} does not match filename template'
        f' {filename_template}.'
    )
  return filename_template.sharded_filepaths_pattern(
      num_shards=self.num_shards
  )
246+
224247

225248
@dataclasses.dataclass(eq=False, frozen=True)
226249
class MultiSplitInfo(SplitInfo):

tensorflow_datasets/core/splits_test.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -641,5 +641,34 @@ def test_missing_shard_lengths(self):
641641
self.assertEqual(files, [])
642642

643643

644+
class SplitInfoTest(testing.TestCase):
  """Tests for `SplitInfo.file_spec`."""

  def test_file_spec(self):
    # Three shard lengths means the spec must end with `@3`.
    split_info = tfds.core.SplitInfo(
        name='train',
        shard_lengths=[1, 2, 3],
        num_bytes=42,
        filename_template=_filename_template(split='train'),
    )
    self.assertEqual(
        split_info.file_spec(
            file_format=tfds.core.file_adapters.FileFormat.TFRECORD
        ),
        '/path/ds_name-train.tfrecord@3',
    )

  def test_file_spec_mismatching_file_format(self):
    # The template built by `_filename_template` yields `.tfrecord` files (see
    # the expected spec in `test_file_spec`), so requesting another file
    # format must be rejected instead of returning a misleading spec.
    split_info = tfds.core.SplitInfo(
        name='train',
        shard_lengths=[1, 2, 3],
        num_bytes=42,
        filename_template=_filename_template(split='train'),
    )
    with self.assertRaises(ValueError):
      split_info.file_spec(
          file_format=tfds.core.file_adapters.FileFormat.RIEGELI
      )

  def test_file_spec_missing_template(self):
    # Without a filename template there is no path to build the spec from.
    split_info = tfds.core.SplitInfo(
        name='train',
        shard_lengths=[1, 2, 3],
        num_bytes=42,
        filename_template=None,
    )
    with self.assertRaises(ValueError):
      split_info.file_spec(
          file_format=tfds.core.file_adapters.FileFormat.TFRECORD
      )
672+
644673
if __name__ == '__main__':
645674
testing.test_main()

0 commit comments

Comments
 (0)