|
31 | 31 |
|
32 | 32 | from absl import logging
|
33 | 33 | from etils import epath
|
| 34 | +from tensorflow_datasets.core import file_adapters |
34 | 35 | from tensorflow_datasets.core import naming
|
35 | 36 | from tensorflow_datasets.core import proto as proto_lib
|
36 | 37 | from tensorflow_datasets.core import units
|
@@ -124,7 +125,7 @@ def from_proto(
|
124 | 125 | cls,
|
125 | 126 | proto: proto_lib.SplitInfo,
|
126 | 127 | filename_template: naming.ShardedFileTemplate,
|
127 |
| - ) -> 'SplitInfo': |
| 128 | + ) -> SplitInfo: |
128 | 129 | """Returns a SplitInfo class instance from a SplitInfo proto."""
|
129 | 130 | return cls(
|
130 | 131 | name=proto.name,
|
@@ -217,10 +218,32 @@ def filepaths(self) -> list[epath.Path]:
|
217 | 218 | self.filename_template.sharded_filepaths(len(self.shard_lengths))
|
218 | 219 | )
|
219 | 220 |
|
220 |
| - def replace(self, **kwargs: Any) -> 'SplitInfo': |
def replace(self, **kwargs: Any) -> SplitInfo:
  """Builds a new `SplitInfo` whose fields are overridden by `kwargs`.

  Args:
    **kwargs: field names mapped to the values they should take in the copy.

  Returns:
    A new instance produced by `dataclasses.replace`; `self` is not modified.
  """
  updated_info = dataclasses.replace(self, **kwargs)
  return updated_info
|
223 | 224 |
|
def file_spec(self, file_format: file_adapters.FileFormat) -> str:
  """Builds the file spec of this split for `file_format`.

  The file spec is the full path written in sharded notation, for example
  `/data/ds/cfg/1.2.3/ds-train@10`.

  Args:
    file_format: the file format the file spec is requested for.

  Returns:
    The sharded file path pattern produced by the split's filename template
    for `self.num_shards` shards.

  Raises:
    ValueError: if the split has no filename template, or if the suffix of
      the filename template differs from the suffix of `file_format`.
  """
  template = self.filename_template
  if not template:
    raise ValueError(
        f'Could not get filename template for split from split info: {self}.'
    )

  # The template is bound to one file type; refuse to build a spec for
  # another format.
  if template.filetype_suffix != file_format.file_suffix:
    raise ValueError(
        f'File format {file_format} does not match filename template'
        f' {template}.'
    )

  return template.sharded_filepaths_pattern(num_shards=self.num_shards)
| 246 | + |
224 | 247 |
|
225 | 248 | @dataclasses.dataclass(eq=False, frozen=True)
|
226 | 249 | class MultiSplitInfo(SplitInfo):
|
|
0 commit comments