Skip to content

Commit 6b4b728

Browse files
author
The TensorFlow Datasets Authors
committed
Parallelize getting dataset infos in convert_format.py.
PiperOrigin-RevId: 686925679
1 parent 40d6b37 commit 6b4b728

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

tensorflow_datasets/scripts/cli/convert_format_utils.py

Lines changed: 33 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -345,16 +345,39 @@ def _convert_dataset_dirs(
345345
logging.info('Converting %d datasets.', len(from_to_dirs))
346346

347347
found_dataset_versions: dict[epath.Path, dataset_info.DatasetInfo] = {}
348-
# TODO(weide) parallelize this, because it's slow for dirs with many datasets.
349-
for from_dir, to_dir in from_to_dirs.items():
350-
info = _get_info_for_dirs_to_convert(
351-
from_dir=from_dir,
352-
to_dir=to_dir,
353-
out_file_format=out_file_format,
354-
overwrite=overwrite,
355-
)
356-
if info is not None:
357-
found_dataset_versions[from_dir] = info
348+
349+
if num_workers > 1:
350+
351+
def _process_get_infos(from_to_dir):
352+
from_dir, to_dir = from_to_dir
353+
return from_dir, _get_info_for_dirs_to_convert(
354+
from_dir=from_dir,
355+
to_dir=to_dir,
356+
out_file_format=out_file_format,
357+
overwrite=overwrite,
358+
)
359+
360+
with concurrent.futures.ThreadPoolExecutor(
361+
max_workers=num_workers
362+
) as executor:
363+
for from_dir, info in executor.map(
364+
_process_get_infos,
365+
from_to_dirs.items(),
366+
):
367+
if info is not None:
368+
found_dataset_versions[from_dir] = info
369+
else:
370+
for from_dir, to_dir in tqdm.tqdm(
371+
from_to_dirs.items(), unit=' directories'
372+
):
373+
info = _get_info_for_dirs_to_convert(
374+
from_dir=from_dir,
375+
to_dir=to_dir,
376+
out_file_format=out_file_format,
377+
overwrite=overwrite,
378+
)
379+
if info is not None:
380+
found_dataset_versions[from_dir] = info
358381

359382
convert_dataset_fn = functools.partial(
360383
_convert_dataset,

0 commit comments

Comments
 (0)