Skip to content

Commit 7d88e04

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Introduce a ConvertConfig dataclass to not have to pass around so many parameters
Also introduce an option to not fail the entire conversion pipeline when a single shard has an error. PiperOrigin-RevId: 688089477
1 parent a08d8d5 commit 7d88e04

File tree

6 files changed

+334
-149
lines changed

6 files changed

+334
-149
lines changed

tensorflow_datasets/core/naming.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,17 @@ def sharded_filepaths_pattern(
691691
replacement = '*'
692692
return _replace_shard_pattern(os.fspath(a_filepath), replacement)
693693

694+
def glob_pattern(self, num_shards: int | None = None) -> str:
695+
"""Returns a glob pattern for all the file paths captured by this template."""
696+
if num_shards is None:
697+
# e.g., `dataset_name-split.fileformat*`
698+
return self.sharded_filepaths_pattern(num_shards=None)
699+
first_shard = self.sharded_filepath(shard_index=0, num_shards=num_shards)
700+
file_name = first_shard.name
701+
file_pattern = re.sub(r'0{5,}-of-', '*-of-', file_name)
702+
# e.g., `dataset_name-split.fileformat-*-of-00042`
703+
return os.fspath(first_shard.parent / file_pattern)
704+
694705
def sharded_filenames(self, num_shards: int) -> list[str]:
695706
return [path.name for path in self.sharded_filepaths(num_shards=num_shards)]
696707

tensorflow_datasets/core/naming_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,19 @@ def test_sharded_file_template_shard_index():
516516
)
517517

518518

519+
def test_glob_pattern():
520+
template = naming.ShardedFileTemplate(
521+
dataset_name='ds',
522+
split='train',
523+
filetype_suffix='tfrecord',
524+
data_dir=epath.Path('/data'),
525+
)
526+
assert '/data/ds-train.tfrecord*' == template.glob_pattern()
527+
assert '/data/ds-train.tfrecord-*-of-00042' == template.glob_pattern(
528+
num_shards=42
529+
)
530+
531+
519532
def test_sharded_file_template_sharded_filepath_shard_x_of_y():
520533
builder_dir = epath.Path('/my/path')
521534
template_explicit = naming.ShardedFileTemplate(

tensorflow_datasets/core/splits.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,38 @@ def __post_init__(self):
120120
# Normalize bytes
121121
super().__setattr__('num_bytes', units.Size(self.num_bytes))
122122

123+
def get_available_shards(
124+
self,
125+
data_dir: epath.Path | None = None,
126+
file_format: file_adapters.FileFormat | None = None,
127+
strict_matching: bool = True,
128+
) -> list[epath.Path]:
129+
"""Returns the list of shards that are present in the data dir.
130+
131+
Args:
132+
data_dir: The data directory to look for shards in. If not provided, the
133+
data directory from the filename template is used.
134+
file_format: The file format to look for shards in. If not provided, the
135+
file format from the filename template is used.
136+
strict_matching: If True, only shards that match the filename template
137+
exactly are returned taking into account the number of shards.
138+
Otherwise, shards that match the template with a wildcard for the shard
139+
number are returned.
140+
"""
141+
if filename_template := self.filename_template:
142+
if file_format:
143+
filename_template = filename_template.replace(
144+
filetype_suffix=file_format.file_suffix
145+
)
146+
data_dir = data_dir or filename_template.data_dir
147+
if strict_matching:
148+
pattern = filename_template.glob_pattern(num_shards=self.num_shards)
149+
else:
150+
pattern = filename_template.sharded_filepaths_pattern(num_shards=None)
151+
return list(data_dir.glob(pattern))
152+
else:
153+
raise ValueError(f'Filename template for split {self.name} is empty.')
154+
123155
@classmethod
124156
def from_proto(
125157
cls,
@@ -382,7 +414,7 @@ class Split(str):
382414
"""
383415

384416
def __repr__(self) -> str:
385-
return '{}({})'.format(type(self).__name__, super(Split, self).__repr__()) # pytype: disable=wrong-arg-types
417+
return f'{type(self).__name__}({super().__repr__()})'
386418

387419

388420
Split.TRAIN = Split('train')
@@ -735,7 +767,9 @@ def _str_to_relative_instruction(spec: str) -> AbstractSplit:
735767
else: # split='train[x:y]' or split='train[x]'
736768
slices = [_SLICE_RE.match(x) for x in split_selector.split(':')]
737769
# Make sure all slices are valid, and at least one is not empty
738-
if not all(slices) or not any(x.group(0) for x in slices): # pytype: disable=attribute-error # re-none
770+
if not all(slices) or not any(
771+
x.group(0) for x in slices if x is not None
772+
): # re-none
739773
raise ValueError(err_msg)
740774
if len(slices) == 1: # split='train[x]'
741775
(from_match,) = slices

tensorflow_datasets/core/splits_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
"""Tests for the Split API."""
1717

18+
import os
19+
from etils import epath
1820
from tensorflow_datasets import testing
1921
from tensorflow_datasets.core import naming
2022
from tensorflow_datasets.core import proto
@@ -669,6 +671,32 @@ def test_file_spec_missing_template(self):
669671
file_format=tfds.core.file_adapters.FileFormat.TFRECORD
670672
)
671673

674+
def test_get_available_shards(self):
675+
tmp_dir = epath.Path(self.tmp_dir)
676+
train_shard1 = tmp_dir / 'ds-train.tfrecord-00000-of-00002'
677+
train_shard1.touch()
678+
train_shard_incorrect = tmp_dir / 'ds-train.tfrecord-00000-of-12345'
679+
train_shard_incorrect.touch()
680+
test_shard1 = tmp_dir / 'ds-test.tfrecord-00000-of-00001'
681+
test_shard1.touch()
682+
683+
split_info = splits.SplitInfo(
684+
name='train',
685+
shard_lengths=[1, 2],
686+
num_bytes=42,
687+
filename_template=_filename_template(
688+
split='train', data_dir=os.fspath(tmp_dir), dataset_name='ds'
689+
),
690+
)
691+
self.assertEqual(
692+
[train_shard1, train_shard_incorrect],
693+
split_info.get_available_shards(tmp_dir, strict_matching=False),
694+
)
695+
self.assertEqual(
696+
[train_shard1],
697+
split_info.get_available_shards(tmp_dir, strict_matching=True),
698+
)
699+
672700

673701
if __name__ == '__main__':
674702
testing.test_main()

0 commit comments

Comments
 (0)