Skip to content

Commit 68244d3

Browse files
author
The TensorFlow Datasets Authors
committed
Add conversion functions to convert_format_utils in TFDS.
PiperOrigin-RevId: 670541104
1 parent 73f926a commit 68244d3

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

tensorflow_datasets/scripts/cli/convert_format_utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
import os
2222
import re
23-
from typing import Type
23+
from typing import Callable, Type
2424

2525
from etils import epy
2626

@@ -31,6 +31,7 @@
3131
from absl import logging
3232
import apache_beam as beam
3333
from etils import epath
34+
import tensorflow as tf
3435
from tensorflow_datasets.core import constants
3536
from tensorflow_datasets.core import dataset_info
3637
from tensorflow_datasets.core import file_adapters
@@ -44,6 +45,9 @@
4445
# pylint: enable=g-import-not-at-top
4546

4647

48+
ConvertFn = Callable[[tf.train.Example], bytes]
49+
50+
4751
@dataclasses.dataclass(frozen=True)
4852
class ShardInstruction:
4953
"""Instruction for how one single shard should be converted."""
@@ -52,12 +56,16 @@ class ShardInstruction:
5256
in_file_adapter: Type[file_adapters.FileAdapter]
5357
out_path: epath.Path
5458
out_file_adapter: Type[file_adapters.FileAdapter]
59+
convert_fn: ConvertFn | None = None
5560

5661
def convert(self) -> None:
5762
def read_in() -> Iterator[type_utils.KeySerializedExample]:
5863
in_dataset = self.in_file_adapter.make_tf_data(filename=self.in_path)
5964
for i, row in enumerate(in_dataset):
60-
yield i, row.numpy()
65+
if self.convert_fn is not None:
66+
yield i, self.convert_fn(row)
67+
else:
68+
yield i, row.numpy()
6169

6270
with py_utils.incomplete_file(self.out_path) as tmp_file:
6371
self.out_file_adapter.write_examples(path=tmp_file, iterator=read_in())
@@ -69,6 +77,7 @@ def _shard_instructions_for_split(
6977
out_path: epath.Path,
7078
in_file_adapter: Type[file_adapters.FileAdapter],
7179
out_file_adapter: Type[file_adapters.FileAdapter],
80+
convert_fn: ConvertFn | None = None,
7281
) -> list[ShardInstruction]:
7382
"""Returns shard instructions for the given split."""
7483

@@ -101,6 +110,7 @@ def _shard_instructions_for_split(
101110
in_file_adapter=in_file_adapter,
102111
out_path=out_path,
103112
out_file_adapter=out_file_adapter,
113+
convert_fn=convert_fn,
104114
)
105115
)
106116
return instructions
@@ -110,6 +120,7 @@ def get_all_shard_instructions(
110120
info: dataset_info.DatasetInfo,
111121
out_file_format: file_adapters.FileFormat,
112122
out_path: epath.Path,
123+
convert_fn: ConvertFn | None = None,
113124
) -> list[ShardInstruction]:
114125
"""Returns all shard instructions for the given dataset info."""
115126
in_file_adapter = file_adapters.ADAPTER_FOR_FORMAT[info.file_format]
@@ -123,6 +134,7 @@ def get_all_shard_instructions(
123134
out_path=out_path,
124135
in_file_adapter=in_file_adapter,
125136
out_file_adapter=out_file_adapter,
137+
convert_fn=convert_fn,
126138
)
127139
)
128140
return shard_instructions
@@ -203,6 +215,7 @@ def _convert_dataset(
203215
out_file_format: file_adapters.FileFormat,
204216
overwrite: bool = False,
205217
pipeline: beam.Pipeline | None = None,
218+
convert_fn: ConvertFn | None = None,
206219
) -> None:
207220
"""Converts a single dataset version to the given file format."""
208221
logging.info(
@@ -220,6 +233,7 @@ def _convert_dataset(
220233
info=info,
221234
out_file_format=out_file_format,
222235
out_path=out_dir,
236+
convert_fn=convert_fn,
223237
)
224238

225239
if not shard_instructions:
@@ -261,6 +275,7 @@ def _convert_dataset_dirs(
261275
overwrite: bool = False,
262276
use_beam: bool = False,
263277
num_workers: int = 8,
278+
convert_fn: ConvertFn | None = None,
264279
) -> None:
265280
"""Converts all datasets in the given `from_to_dirs` parameter.
266281
@@ -271,6 +286,8 @@ def _convert_dataset_dirs(
271286
overwrite: whether to overwrite the to_dirs if they exist.
272287
use_beam: whether to use Beam to convert the datasets.
273288
num_workers: number of workers to use if `use_beam` is `False`.
289+
convert_fn: optional function applied to each record read from the input
290+
shards; returns the bytes to write. If None, records are copied as is.
274291
"""
275292
logging.info('Converting %d datasets.', len(from_to_dirs))
276293

@@ -310,6 +327,7 @@ def _convert_dataset_dirs(
310327
_convert_dataset,
311328
out_file_format=out_file_format,
312329
overwrite=overwrite,
330+
convert_fn=convert_fn,
313331
)
314332

315333
# First convert all shards (with or without Beam), then convert the metadata.
@@ -392,6 +410,7 @@ def convert_root_data_dir(
392410
use_beam: bool,
393411
overwrite: bool = False,
394412
num_workers: int = 8,
413+
convert_fn: ConvertFn | None = None,
395414
) -> None:
396415
"""Converts all datasets found in the given dataset dir.
397416
@@ -406,6 +425,8 @@ def convert_root_data_dir(
406425
use_beam: whether to use Beam to convert datasets. Useful for big datasets.
407426
overwrite: whether to overwrite folders in `out_dir` if they already exist.
408427
num_workers: number of workers to use if `use_beam` is `False`.
428+
convert_fn: optional function applied to each record read from the input
429+
shards; returns the bytes to write. If None, records are copied as is.
409430
"""
410431
root_data_dir = epath.Path(root_data_dir)
411432
out_path = epath.Path(out_dir) if out_dir is not None else None
@@ -432,6 +453,7 @@ def convert_root_data_dir(
432453
overwrite=overwrite,
433454
use_beam=use_beam,
434455
num_workers=num_workers,
456+
convert_fn=convert_fn,
435457
)
436458

437459

@@ -453,6 +475,7 @@ def convert_dataset_dir(
453475
use_beam: bool,
454476
overwrite: bool = False,
455477
num_workers: int = 8,
478+
convert_fn: ConvertFn | None = None,
456479
) -> None:
457480
"""Converts all datasets found in the given dataset dir.
458481
@@ -467,6 +490,8 @@ def convert_dataset_dir(
467490
use_beam: whether to use Beam to convert datasets. Useful for big datasets.
468491
overwrite: whether to overwrite folders in `out_dir` if they already exist.
469492
num_workers: number of workers to use if `use_beam` is `False`.
493+
convert_fn: optional function applied to each record read from the input
494+
shards; returns the bytes to write. If None, records are copied as is.
470495
"""
471496
dataset_dir = epath.Path(dataset_dir)
472497
out_path = epath.Path(out_dir) if out_dir is not None else None
@@ -495,6 +520,7 @@ def convert_dataset_dir(
495520
overwrite=overwrite,
496521
use_beam=use_beam,
497522
num_workers=num_workers,
523+
convert_fn=convert_fn,
498524
)
499525

500526

@@ -509,6 +535,7 @@ def convert_dataset(
509535
overwrite: bool = False,
510536
use_beam: bool = False,
511537
num_workers: int = 8,
538+
convert_fn: ConvertFn | None = None,
512539
) -> None:
513540
"""Convert a dataset from one file format to another format.
514541
@@ -534,6 +561,8 @@ def convert_dataset(
534561
num_workers: number of workers to use when not using Beam. If `use_beam` is
535562
set, this flag is ignored. If `num_workers=1`, the conversion will be done
536563
sequentially.
564+
convert_fn: optional function applied to each record read from the input
565+
shards; returns the bytes to write. If None, records are copied as is.
537566
"""
538567
if (
539568
root_data_dir is None
@@ -588,6 +617,7 @@ def convert_dataset(
588617
overwrite=overwrite,
589618
use_beam=use_beam,
590619
num_workers=num_workers,
620+
convert_fn=convert_fn,
591621
)
592622
else:
593623
raise ValueError(

0 commit comments

Comments
 (0)