Commit 27e89c3

Author: The TensorFlow Datasets Authors
Commit message: Internal change
PiperOrigin-RevId: 640120829
Parent: 953b351

6 files changed: 126 additions, 27 deletions

tensorflow_datasets/core/dataset_info.py

Lines changed: 20 additions & 4 deletions
@@ -336,10 +336,7 @@ def as_proto(self) -> dataset_info_pb2.DatasetInfo:

   @property
   def as_proto_with_features(self) -> dataset_info_pb2.DatasetInfo:
-    info_proto = dataset_info_pb2.DatasetInfo()
-    info_proto.CopyFrom(self._info_proto)
-    info_proto.features.CopyFrom(self.features.to_proto())  # pytype: disable=attribute-error  # always-use-property-annotation
-    return info_proto
+    return update_info_proto_with_features(self._info_proto, self.features)

   @property
   def name(self) -> str:

@@ -1196,6 +1193,25 @@ def _create_redistribution_info_proto(
   return None


+def update_info_proto_with_features(
+    info_proto: dataset_info_pb2.DatasetInfo,
+    features: feature_lib.FeatureConnector,
+) -> dataset_info_pb2.DatasetInfo:
+  """Update the info proto with the given features, if any.
+
+  Args:
+    info_proto: the info proto to update.
+    features: the features to use.
+
+  Returns:
+    the updated info proto.
+  """
+  completed_info_proto = dataset_info_pb2.DatasetInfo()
+  completed_info_proto.CopyFrom(info_proto)
+  completed_info_proto.features.CopyFrom(features.to_proto())
+  return completed_info_proto
+
+
 class MetadataDict(Metadata, dict):
   """A `tfds.core.Metadata` object that acts as a `dict`.

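The refactor extracts the copy-then-fill logic into a reusable module-level helper that returns a copy rather than mutating its input. A minimal sketch of the intended call pattern, assuming an existing `DatasetInfo` object named `info` (the feature spec below is made up for illustration):

# Sketch only: derive a features-complete proto without mutating the original.
from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import features as features_lib

new_features = features_lib.FeaturesDict({'text': features_lib.Text()})
completed = dataset_info.update_info_proto_with_features(
    info.as_proto,  # `info` is assumed to be an existing DatasetInfo
    new_features,
)
# `completed` is a copy: `info.as_proto` is left untouched, while
# `completed.features` now holds the serialized feature spec.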
tensorflow_datasets/core/dataset_info_test.py

Lines changed: 8 additions & 0 deletions
@@ -422,6 +422,14 @@ def test_set_file_format_override(self):
     info.set_file_format(file_adapters.FileFormat.RIEGELI, override=True)
     self.assertEqual(info.file_format, file_adapters.FileFormat.RIEGELI)

+  def test_update_info_proto_with_features(self):
+    info_proto = dataset_info.DatasetInfo(builder=self._builder).as_proto
+    new_features = features.FeaturesDict({"text": features.Text()})
+    new_info = dataset_info.update_info_proto_with_features(
+        info_proto, new_features
+    )
+    self.assertEqual(new_info.features, new_features.to_proto())
+

 @pytest.mark.parametrize(
     "file_format",

tensorflow_datasets/core/features/feature.py

Lines changed: 1 addition & 1 deletion
@@ -1068,7 +1068,7 @@ def load_metadata(
     will restore the feature metadata from the saved file.

     Args:
-      data_dir: path to the dataset folder to which save the info (ex:
+      data_dir: path to the dataset folder where the info is saved (ex:
         `~/datasets/cifar10/1.2.0/`)
       feature_name: the name of the feature (from the FeaturesDict key)
     """

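For context, `load_metadata` is the counterpart of `save_metadata`; a rough sketch of a call, with the path and feature name purely illustrative:

import tensorflow_datasets as tfds

# Sketch only: restore metadata for a single feature from a dataset folder.
feature = tfds.features.Image()
feature.load_metadata(
    data_dir='~/datasets/cifar10/1.2.0/',  # folder where the info is saved
    feature_name='image',  # the key of this feature in its FeaturesDict
)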
tensorflow_datasets/core/naming.py

Lines changed: 5 additions & 1 deletion
@@ -208,6 +208,9 @@ class DatasetReference:
       can be used. For example, if the collection uses the split `valid`, but
       this dataset uses the split `validation`, then the `split_mapping` should
       be `{'validation': 'valid'}`.
+    info_filenames: Filenames which are used to describe the dataset. They might
+      include, for example, `dataset_info.json`, `features.json`, etc. If None,
+      then it wasn't checked which info files exist on disk.
   """  # fmt: skip

   dataset_name: str

@@ -216,6 +219,7 @@ class DatasetReference:
   version: None | str | version_lib.Version = None
   data_dir: None | str | os.PathLike = None  # pylint: disable=g-bare-generic
   split_mapping: None | Mapping[str, str] = None
+  info_filenames: set[str] | None = None

   def __post_init__(self):
     if isinstance(self.version, version_lib.Version):

@@ -302,7 +306,7 @@ def from_tfds_name(


 def references_for(
-    name_to_tfds_name: Mapping[str, str]
+    name_to_tfds_name: Mapping[str, str],
 ) -> Mapping[str, DatasetReference]:
   """Constructs dataset references.


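To make the new field concrete, here is a hand-built reference (a sketch with illustrative values; in practice references come out of the discovery helpers in `file_utils` rather than being constructed by hand):

from tensorflow_datasets.core import naming

ref = naming.DatasetReference(
    dataset_name='my_ds',
    config='x',
    version='1.0.0',
    data_dir='/path/to/data_dir',
    info_filenames={'dataset_info.json', 'features.json'},
)
# info_filenames=None (the default) means the info files on disk were
# never checked, not that none exist.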
tensorflow_datasets/core/utils/file_utils.py

Lines changed: 30 additions & 7 deletions
@@ -177,15 +177,20 @@ def _find_files_without_glob(


 def _find_files_with_glob(
-    folder: epath.Path, globs: list[str], file_names: list[str]
+    folder: epath.Path,
+    globs: list[str],
+    file_names: list[str],
 ) -> Iterator[epath.Path]:
   """Finds files matching any of the given globs and given file names."""
   for glob in globs:
+    found_files = folder.glob(glob)
     try:
-      for file in folder.glob(glob):
+      for file in found_files:
         if file.name in file_names:
           yield file
-    except OSError:
+    except (
+        OSError,
+    ):
       # If permission was denied on any subfolder, then the glob fails. Manually
       # iterate through the subfolders instead to be more robust against this.
       yield from _find_files_without_glob(folder, globs, file_names)

@@ -197,6 +202,7 @@ def _find_references_with_glob(
     is_dataset_dir: bool,
     namespace: str | None = None,
     include_old_tfds_version: bool = True,
+    glob_suffixes: Sequence[str] = ('json',),
 ) -> Iterator[naming.DatasetReference]:
   """Yields all dataset references in the given folder.

@@ -208,6 +214,8 @@ def _find_references_with_glob(
     namespace: Optional namespace to which the found datasets belong to.
     include_old_tfds_version: include datasets that have been generated with
       TFDS before 4.0.0.
+    glob_suffixes: list of file suffixes to use to create the glob for
+      interesting TFDS files. Defaults to json files.

   Yields:
     all dataset references in the given folder.

@@ -220,16 +228,26 @@ def _find_references_with_glob(
   if is_data_dir:
     data_dir = folder
     dataset_name = None
-    globs = ['*/*/*/*.json', '*/*/*.json']
+    stars = ['*/*/*/*', '*/*/*']
   else:
     data_dir = folder.parent
     dataset_name = folder.name
-    globs = ['*/*/*.json', '*/*.json']
+    stars = ['*/*/*', '*/*']
+
+  globs = [f'{star}.{suffix}' for star in stars for suffix in glob_suffixes]  # pylint:disable=g-complex-comprehension

   # Check files matching the globs and are files we are interested in.
   matched_files_per_folder = collections.defaultdict(set)
-  file_names = [constants.FEATURES_FILENAME, constants.DATASET_INFO_FILENAME]
-  for file in _find_files_with_glob(folder, globs=globs, file_names=file_names):
+  file_names = [
+      constants.FEATURES_FILENAME,
+      constants.DATASET_INFO_FILENAME,
+  ]
+
+  for file in _find_files_with_glob(
+      folder,
+      globs=globs,
+      file_names=file_names,
+  ):
     matched_files_per_folder[file.parent].add(file.name)

   for data_folder, matched_files in matched_files_per_folder.items():

@@ -284,6 +302,7 @@ def _find_references_with_glob(
         dataset_name=dataset_name,
         config=config,
         version=version,
+        info_filenames=matched_files,
     )

@@ -292,6 +311,7 @@ def list_dataset_variants(
     namespace: str | None = None,
     include_versions: bool = True,
     include_old_tfds_version: bool = False,
+    glob_suffixes: Sequence[str] = ('json',),
 ) -> Iterator[naming.DatasetReference]:
   """Yields all variants (config + version) found in `dataset_dir`.

@@ -301,6 +321,8 @@ def list_dataset_variants(
     include_versions: whether to list what versions are available.
     include_old_tfds_version: include datasets that have been generated with
       TFDS before 4.0.0.
+    glob_suffixes: list of file suffixes to use to create the glob for
+      interesting TFDS files. Defaults to json files.

   Yields:
     all variants of the given dataset.

@@ -313,6 +335,7 @@ def list_dataset_variants(
       is_dataset_dir=True,
       namespace=namespace,
       include_old_tfds_version=include_old_tfds_version,
+      glob_suffixes=glob_suffixes,
   ):
     if include_versions:
       key = f'{reference.dataset_name}/{reference.config}:{reference.version}'

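Putting the new parameter to work, a hedged sketch of listing variants while setting `glob_suffixes` explicitly (the directory layout and path are assumptions):

from tensorflow_datasets.core.utils import file_utils

# Assumed layout: /data/my_ds/<config>/<version>/dataset_info.json
for reference in file_utils.list_dataset_variants(
    dataset_dir='/data/my_ds',
    glob_suffixes=('json',),  # the default; other suffixes widen the glob
):
  print(reference.version, reference.info_filenames)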
tensorflow_datasets/core/utils/file_utils_test.py

Lines changed: 62 additions & 14 deletions
@@ -41,21 +41,44 @@ def test_list_dataset_variants_with_configs(mock_fs: testing.MockFs):
       'x': ['1.0.0', '1.0.1'],
       'y': ['2.0.0'],
   }
+  info_filenames = {
+      'features.json',
+      'dataset_info.json',
+  }
+  glob_suffixes = [
+      'json',
+  ]
   for config, versions in configs_and_versions.items():
     for version in versions:
-      mock_fs.add_file(dataset_dir / config / version / 'dataset_info.json')
-      mock_fs.add_file(dataset_dir / config / version / 'features.json')
+      for info_filename in info_filenames:
+        mock_fs.add_file(dataset_dir / config / version / info_filename)

-  references = sorted(file_utils.list_dataset_variants(dataset_dir=dataset_dir))
+  references = sorted(
+      file_utils.list_dataset_variants(
+          dataset_dir=dataset_dir, glob_suffixes=glob_suffixes
+      )
+  )
   assert references == [
       naming.DatasetReference(
-          dataset_name='my_ds', config='x', version='1.0.0', data_dir=data_dir
+          dataset_name='my_ds',
+          config='x',
+          version='1.0.0',
+          data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
       naming.DatasetReference(
-          dataset_name='my_ds', config='x', version='1.0.1', data_dir=data_dir
+          dataset_name='my_ds',
+          config='x',
+          version='1.0.1',
+          data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
       naming.DatasetReference(
-          dataset_name='my_ds', config='y', version='2.0.0', data_dir=data_dir
+          dataset_name='my_ds',
+          config='y',
+          version='2.0.0',
+          data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
   ]

@@ -69,10 +92,12 @@ def test_list_dataset_variants_with_configs_no_versions(
       'x': ['1.0.0', '1.0.1'],
       'y': ['2.0.0'],
   }
+  info_filenames = {'dataset_info.json', 'features.json'}
   for config, versions in configs_and_versions.items():
     for version in versions:
-      mock_fs.add_file(dataset_dir / config / version / 'dataset_info.json')
-      mock_fs.add_file(dataset_dir / config / version / 'features.json')
+      for filename in info_filenames:
+        mock_fs.add_file(dataset_dir / config / version / filename)
+        mock_fs.add_file(dataset_dir / config / version / filename)

   references = sorted(
       file_utils.list_dataset_variants(

@@ -81,10 +106,16 @@
   )
   assert references == [
       naming.DatasetReference(
-          dataset_name='my_ds', config='x', data_dir=data_dir
+          dataset_name='my_ds',
+          config='x',
+          data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
       naming.DatasetReference(
-          dataset_name='my_ds', config='y', data_dir=data_dir
+          dataset_name='my_ds',
+          config='y',
+          data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
   ]

@@ -108,10 +139,16 @@ def test_list_dataset_variants_without_configs(mock_fs: testing.MockFs):
   )
   assert references == [
       naming.DatasetReference(
-          dataset_name='my_ds', version='1.0.0', data_dir=data_dir
+          dataset_name='my_ds',
+          version='1.0.0',
+          data_dir=data_dir,
+          info_filenames={'dataset_info.json'},
       ),
       naming.DatasetReference(
-          dataset_name='my_ds', version='1.0.1', data_dir=data_dir
+          dataset_name='my_ds',
+          version='1.0.1',
+          data_dir=data_dir,
+          info_filenames={'dataset_info.json', 'features.json'},
       ),
   ]

@@ -125,7 +162,10 @@ def test_list_dataset_variants_without_configs(mock_fs: testing.MockFs):
   )
   assert references == [
       naming.DatasetReference(
-          dataset_name='my_ds', version='1.0.1', data_dir=data_dir
+          dataset_name='my_ds',
+          version='1.0.1',
+          data_dir=data_dir,
+          info_filenames={'dataset_info.json', 'features.json'},
       )
   ]

@@ -140,6 +180,7 @@ def test_list_datasets_in_data_dir(mock_fs: testing.MockFs):
   mock_fs.add_file(data_dir / 'ds1/config2/1.0.0/features.json')
   mock_fs.add_file(data_dir / 'ds2/1.0.0/dataset_info.json')
   mock_fs.add_file(data_dir / 'ds2/1.0.0/features.json')
+  info_filenames = {'dataset_info.json', 'features.json'}

   # The following are problematic and should thus be ignored.
   mock_fs.add_file(

@@ -164,21 +205,27 @@ def test_list_datasets_in_data_dir(mock_fs: testing.MockFs):
           config='config1',
           version='1.0.0',
           data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
       naming.DatasetReference(
           dataset_name='ds1',
           config='config1',
           version='2.0.0',
           data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
       naming.DatasetReference(
           dataset_name='ds1',
           config='config2',
           version='1.0.0',
           data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
       naming.DatasetReference(
-          dataset_name='ds2', version='1.0.0', data_dir=data_dir
+          dataset_name='ds2',
+          version='1.0.0',
+          data_dir=data_dir,
+          info_filenames=info_filenames,
       ),
   ]

@@ -205,6 +252,7 @@ def test_list_datasets_in_data_dir_with_namespace(mock_fs: testing.MockFs):
           config='config1',
           version='1.0.0',
           data_dir=data_dir,
+          info_filenames={'dataset_info.json', 'features.json'},
       ),
   ]
