Skip to content

Commit 882d2e3

Browse files
author
The TensorFlow Datasets Authors
committed
Include metadata conversion in the beam pipeline.
PiperOrigin-RevId: 693278977
1 parent c96a3ed commit 882d2e3

File tree

2 files changed

+113
-40
lines changed

2 files changed

+113
-40
lines changed

tensorflow_datasets/scripts/cli/convert_format_utils.py

Lines changed: 101 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,16 @@ class ShardInstruction:
105105
out_path: epath.Path
106106
config: ConvertConfig
107107

108-
def convert(self) -> None:
109-
"""Converts the shard to the desired file format."""
108+
def convert(self) -> epath.Path | None:
109+
"""Converts the shard to the desired file format.
110+
111+
Returns:
112+
The path of the converted shard or `None` if the shard was not converted.
113+
114+
Raises:
115+
Exception: if the shard conversion failed and `config.fail_on_error` is
116+
`True`, else logs the error.
117+
"""
110118

111119
def read_in() -> Iterator[type_utils.KeySerializedExample]:
112120
in_dataset = self.config.in_file_adapter.make_tf_data(
@@ -127,12 +135,13 @@ def read_in() -> Iterator[type_utils.KeySerializedExample]:
127135
self.config.out_file_adapter.write_examples(
128136
path=tmp_file, iterator=read_in()
129137
)
138+
return self.out_path
130139
except Exception as e: # pylint: disable=broad-except
131140
if self.config.fail_on_error:
132141
raise e
133142
else:
134143
logging.exception(
135-
'Failed to convert shard %s (format=%s) to %s (format=%s!',
144+
'Failed to convert shard %s (format=%s) to %s (format=%s)!',
136145
self.in_path,
137146
self.config.in_file_adapter.FILE_SUFFIX,
138147
self.out_path,
@@ -220,45 +229,87 @@ def _get_root_data_dir(
220229
return epath.Path(re.sub(rf'{relative_data_dir}/?$', '', in_dir))
221230

222231

232+
class ConvertMetadataFn(beam.DoFn):
233+
"""Beam DoFn to convert metadata for a single dataset version."""
234+
235+
def process(
236+
self,
237+
count,
238+
in_dir: epath.Path,
239+
info: dataset_info_pb2.DatasetInfo,
240+
out_path: epath.Path,
241+
convert_config: ConvertConfig,
242+
):
243+
# This is necessary because `beam.combiners.Count.Globally()` sometimes
244+
# returns a plain integer rather than a PCollection.
245+
if not isinstance(count, int):
246+
count = beam.pvalue.AsSingleton(count)
247+
convert_metadata(
248+
in_dir=in_dir,
249+
info=info,
250+
out_path=out_path,
251+
convert_config=convert_config,
252+
num_converted_shards=count,
253+
)
254+
255+
223256
def convert_metadata(
224257
in_dir: epath.Path,
225258
info: dataset_info_pb2.DatasetInfo,
226-
out_file_format: file_adapters.FileFormat,
227259
out_path: epath.Path,
260+
convert_config: ConvertConfig,
261+
num_converted_shards: int | None = None,
228262
) -> None:
229263
"""Converts all metadata to the converted dataset.
230264
231265
Args:
232266
in_dir: folder that contains the dataset to convert.
233267
info: dataset info of the dataset to convert.
234-
out_file_format: the format to which the dataset should be converted to.
235268
out_path: folder where the converted dataset should be stored.
269+
convert_config: configuration for the conversion.
270+
num_converted_shards: number of shards that were successfully converted,
271+
which is used to check that the conversion was successful. If part of a
272+
beam pipeline, this comes from `beam.combiners.Count.Globally()`.
236273
"""
237274
splits_dict = dataset_info_lib.get_split_dict_from_proto(
238275
dataset_info_proto=info,
239276
data_dir=in_dir,
240-
file_format=out_file_format,
277+
file_format=convert_config.out_file_format,
241278
)
242279

243280
missing_shards_per_split = {}
244281
for split_info in splits_dict.values():
245-
available_shards = split_info.get_available_shards(
246-
out_path, file_format=out_file_format
282+
num_available_shards = len(
283+
split_info.get_available_shards(
284+
out_path, file_format=convert_config.out_file_format
285+
)
247286
)
248-
if len(available_shards) < split_info.num_shards:
287+
if num_converted_shards != num_available_shards:
288+
logging.warning(
289+
'Amount of converted shards calculated during conversion (%d) does'
290+
' not match the amount of available shards in the data dir (%d) for'
291+
' split %s.'
292+
)
293+
if num_available_shards < split_info.num_shards:
249294
missing_shards_per_split[split_info.name] = (
250-
len(available_shards),
295+
num_available_shards,
251296
split_info.num_shards,
252297
)
253-
logging.warning(
254-
'Found %d shards for split %s, but expected %d shards.',
255-
len(available_shards),
256-
split_info.name,
257-
split_info.num_shards,
298+
error_message = (
299+
(
300+
f'Found {num_available_shards} shards for split'
301+
f' {split_info.name}, but expected'
302+
f' {split_info.num_shards} shards.'
303+
),
258304
)
259-
elif len(available_shards) > split_info.num_shards:
305+
if convert_config.fail_on_error:
306+
raise ValueError(error_message)
307+
else:
308+
logging.warning(error_message)
309+
310+
elif num_available_shards > split_info.num_shards:
260311
raise ValueError(
261-
f'Found more shards ({len(available_shards)}) for split'
312+
f'Found more shards ({num_available_shards}) for split'
262313
f' {split_info.name}, but expected only'
263314
f' {split_info.num_shards} shards.'
264315
)
@@ -278,14 +329,14 @@ def convert_metadata(
278329

279330
# File format was added to an existing dataset.
280331
# Add the file format to `alternative_file_formats` field.
281-
if out_file_format not in info.alternative_file_formats:
282-
info.alternative_file_formats.append(out_file_format.value)
332+
if convert_config.out_file_format not in info.alternative_file_formats:
333+
info.alternative_file_formats.append(convert_config.out_file_format.value)
283334
dataset_info_lib.write_dataset_info_proto(info, dataset_info_dir=out_path)
284335
else:
285336
logging.info(
286337
'File format %s is already an alternative file format of the dataset'
287338
' in %s. Skipping updating metadata..',
288-
out_file_format.value,
339+
convert_config.out_file_format.value,
289340
os.fspath(in_dir),
290341
)
291342
return
@@ -318,7 +369,7 @@ def convert_metadata(
318369
dataset_info_proto=info,
319370
dataset_reference=in_dataset_reference,
320371
)
321-
info.file_format = out_file_format.value
372+
info.file_format = convert_config.out_file_format.value
322373
dataset_info_lib.write_dataset_info_proto(info, dataset_info_dir=out_path)
323374

324375

@@ -359,21 +410,46 @@ def _convert_dataset(
359410
logging.info('Found %d shards to convert.', len(shard_instructions))
360411

361412
if pipeline is not None:
362-
_ = (
413+
converted_shards = (
363414
pipeline
364415
| f'CreateShardInstructions for {dataset_dir}'
365416
>> beam.Create(shard_instructions)
366417
| f'ConvertShards for {dataset_dir}'
367418
>> beam.Map(lambda shard_instruction: shard_instruction.convert())
419+
| f'Filter out shards that were not successfully converted for {dataset_dir}'
420+
>> beam.Filter(lambda shard_instruction: shard_instruction is not None)
421+
)
422+
count_shards = (
423+
converted_shards
424+
| f'CountConvertedShards for {dataset_dir}'
425+
>> beam.combiners.Count.Globally()
426+
)
427+
_ = count_shards | f'ConvertMetadata for {dataset_dir}' >> beam.ParDo(
428+
ConvertMetadataFn(),
429+
in_dir=dataset_dir,
430+
info=info,
431+
out_path=out_dir,
432+
convert_config=convert_config,
368433
)
369434

370435
else:
436+
converted_shards = 0
371437
for shard_instruction in tqdm.tqdm(
372438
shard_instructions,
373439
unit=' shards',
374440
desc=f'Shards in {os.fspath(dataset_dir)}',
375441
):
376-
shard_instruction.convert()
442+
result = shard_instruction.convert()
443+
if result is not None:
444+
converted_shards += 1
445+
logging.info('Converting metadata in %s.', dataset_dir)
446+
convert_metadata(
447+
in_dir=dataset_dir,
448+
info=info,
449+
out_path=out_dir,
450+
convert_config=convert_config,
451+
num_converted_shards=converted_shards,
452+
)
377453

378454

379455
def _remove_incomplete_files(path: epath.Path) -> None:
@@ -521,21 +597,9 @@ def _process_get_infos(from_to_dir):
521597
out_dir=out_dir,
522598
)
523599

524-
logging.info('All shards have been converted. Now converting metadata.')
525-
for dataset_dir, info in tqdm.tqdm(
526-
found_dataset_versions.items(), unit=' datasets'
527-
):
528-
out_dir = from_to_dirs[dataset_dir]
529-
logging.info('Converting metadata in %s.', dataset_dir)
530-
convert_metadata(
531-
in_dir=dataset_dir,
532-
info=info,
533-
out_file_format=convert_config.out_file_format,
534-
out_path=out_dir,
535-
)
536-
537600
logging.info(
538-
'All metadata has been converted. Now removing incomplete files.'
601+
'All metadata and shards have been converted. Now removing incomplete'
602+
' files.'
539603
)
540604
for out_dir in from_to_dirs.values():
541605
logging.info('Removing incomplete files in %s.', out_dir)

tensorflow_datasets/scripts/cli/convert_format_utils_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,14 @@ def test_record_source_dataset(tmpdir):
184184
in_data_dir.mkdir(parents=True)
185185
out_data_dir.mkdir(parents=True)
186186
info = _create_dataset_info(in_data_dir)
187+
convert_config = convert_format_utils.ConvertConfig(
188+
out_file_format=file_adapters.FileFormat.RIEGELI
189+
)
187190
convert_format_utils.convert_metadata(
188191
in_dir=in_data_dir,
189192
out_path=out_data_dir,
190193
info=info,
191-
out_file_format=file_adapters.FileFormat.RIEGELI,
194+
convert_config=convert_config,
192195
)
193196
converted_info = dataset_info_lib.read_proto_from_builder_dir(out_data_dir)
194197
assert converted_info.name == info.name
@@ -212,12 +215,15 @@ def test_convert_metadata_add_to_existing(tmpdir):
212215
# Create a converted shard in the input directory.
213216
converted_shard = in_data_dir / 'a-train.riegeli-00000-of-00001'
214217
converted_shard.touch()
218+
convert_config = convert_format_utils.ConvertConfig(
219+
out_file_format=file_adapters.FileFormat.RIEGELI
220+
)
215221

216222
convert_format_utils.convert_metadata(
217223
in_dir=in_data_dir,
218224
out_path=in_data_dir,
219225
info=info,
220-
out_file_format=file_adapters.FileFormat.RIEGELI,
226+
convert_config=convert_config,
221227
)
222228
converted_info = dataset_info_lib.read_proto_from_builder_dir(in_data_dir)
223229
assert converted_info.file_format == file_adapters.FileFormat.TFRECORD.value
@@ -232,11 +238,14 @@ def test_convert_metadata_missing_shards(tmpdir):
232238
info = _create_dataset_info(
233239
in_data_dir, split_lengths={'train': 2, 'test': 1}
234240
)
241+
convert_config = convert_format_utils.ConvertConfig(
242+
out_file_format=file_adapters.FileFormat.RIEGELI
243+
)
235244
convert_format_utils.convert_metadata(
236245
in_dir=in_data_dir,
237246
out_path=in_data_dir,
238247
info=info,
239-
out_file_format=file_adapters.FileFormat.RIEGELI,
248+
convert_config=convert_config,
240249
)
241250
converted_info = dataset_info_lib.read_proto_from_builder_dir(in_data_dir)
242251
assert converted_info.file_format == file_adapters.FileFormat.TFRECORD.value

0 commit comments

Comments
 (0)