20
20
import functools
21
21
import os
22
22
import re
23
- from typing import Type
23
+ from typing import Callable , Type
24
24
25
25
from etils import epy
26
26
31
31
from absl import logging
32
32
import apache_beam as beam
33
33
from etils import epath
34
+ import tensorflow as tf
34
35
from tensorflow_datasets .core import constants
35
36
from tensorflow_datasets .core import dataset_info
36
37
from tensorflow_datasets .core import file_adapters
44
45
# pylint: enable=g-import-not-at-top
45
46
46
47
48
+ ConvertFn = Callable [[tf .train .Example ], bytes ]
49
+
50
+
47
51
@dataclasses .dataclass (frozen = True )
48
52
class ShardInstruction :
49
53
"""Instruction for how one single shard should be converted."""
@@ -52,12 +56,16 @@ class ShardInstruction:
52
56
in_file_adapter : Type [file_adapters .FileAdapter ]
53
57
out_path : epath .Path
54
58
out_file_adapter : Type [file_adapters .FileAdapter ]
59
+ convert_fn : ConvertFn | None = None
55
60
56
61
def convert (self ) -> None :
57
62
def read_in () -> Iterator [type_utils .KeySerializedExample ]:
58
63
in_dataset = self .in_file_adapter .make_tf_data (filename = self .in_path )
59
64
for i , row in enumerate (in_dataset ):
60
- yield i , row .numpy ()
65
+ if self .convert_fn is not None :
66
+ yield i , self .convert_fn (row )
67
+ else :
68
+ yield i , row .numpy ()
61
69
62
70
with py_utils .incomplete_file (self .out_path ) as tmp_file :
63
71
self .out_file_adapter .write_examples (path = tmp_file , iterator = read_in ())
@@ -69,6 +77,7 @@ def _shard_instructions_for_split(
69
77
out_path : epath .Path ,
70
78
in_file_adapter : Type [file_adapters .FileAdapter ],
71
79
out_file_adapter : Type [file_adapters .FileAdapter ],
80
+ convert_fn : ConvertFn | None = None ,
72
81
) -> list [ShardInstruction ]:
73
82
"""Returns shard instructions for the given split."""
74
83
@@ -101,6 +110,7 @@ def _shard_instructions_for_split(
101
110
in_file_adapter = in_file_adapter ,
102
111
out_path = out_path ,
103
112
out_file_adapter = out_file_adapter ,
113
+ convert_fn = convert_fn ,
104
114
)
105
115
)
106
116
return instructions
@@ -110,6 +120,7 @@ def get_all_shard_instructions(
110
120
info : dataset_info .DatasetInfo ,
111
121
out_file_format : file_adapters .FileFormat ,
112
122
out_path : epath .Path ,
123
+ convert_fn : ConvertFn | None = None ,
113
124
) -> list [ShardInstruction ]:
114
125
"""Returns all shard instructions for the given dataset info."""
115
126
in_file_adapter = file_adapters .ADAPTER_FOR_FORMAT [info .file_format ]
@@ -123,6 +134,7 @@ def get_all_shard_instructions(
123
134
out_path = out_path ,
124
135
in_file_adapter = in_file_adapter ,
125
136
out_file_adapter = out_file_adapter ,
137
+ convert_fn = convert_fn ,
126
138
)
127
139
)
128
140
return shard_instructions
@@ -203,6 +215,7 @@ def _convert_dataset(
203
215
out_file_format : file_adapters .FileFormat ,
204
216
overwrite : bool = False ,
205
217
pipeline : beam .Pipeline | None = None ,
218
+ convert_fn : ConvertFn | None = None ,
206
219
) -> None :
207
220
"""Converts a single dataset version to the given file format."""
208
221
logging .info (
@@ -220,6 +233,7 @@ def _convert_dataset(
220
233
info = info ,
221
234
out_file_format = out_file_format ,
222
235
out_path = out_dir ,
236
+ convert_fn = convert_fn ,
223
237
)
224
238
225
239
if not shard_instructions :
@@ -261,6 +275,7 @@ def _convert_dataset_dirs(
261
275
overwrite : bool = False ,
262
276
use_beam : bool = False ,
263
277
num_workers : int = 8 ,
278
+ convert_fn : ConvertFn | None = None ,
264
279
) -> None :
265
280
"""Converts all datasets in the given `from_to_dirs` parameter.
266
281
@@ -271,6 +286,8 @@ def _convert_dataset_dirs(
271
286
overwrite: whether to overwrite the to_dirs if they exist.
272
287
use_beam: whether to use Beam to convert the datasets.
273
288
num_workers: number of workers to use if `use_beam` is `False`.
289
+ convert_fn: optional function to convert TF examples into whatever is
290
+ desired.
274
291
"""
275
292
logging .info ('Converting %d datasets.' , len (from_to_dirs ))
276
293
@@ -310,6 +327,7 @@ def _convert_dataset_dirs(
310
327
_convert_dataset ,
311
328
out_file_format = out_file_format ,
312
329
overwrite = overwrite ,
330
+ convert_fn = convert_fn ,
313
331
)
314
332
315
333
# First convert all shards (with or without Beam), then convert the metadata.
@@ -392,6 +410,7 @@ def convert_root_data_dir(
392
410
use_beam : bool ,
393
411
overwrite : bool = False ,
394
412
num_workers : int = 8 ,
413
+ convert_fn : ConvertFn | None = None ,
395
414
) -> None :
396
415
"""Converts all datasets found in the given dataset dir.
397
416
@@ -406,6 +425,8 @@ def convert_root_data_dir(
406
425
use_beam: whether to use Beam to convert datasets. Useful for big datasets.
407
426
overwrite: whether to overwrite folders in `out_dir` if they already exist.
408
427
num_workers: number of workers to use if `use_beam` is `False`.
428
+ convert_fn: optional function to convert TF examples into whatever is
429
+ desired.
409
430
"""
410
431
root_data_dir = epath .Path (root_data_dir )
411
432
out_path = epath .Path (out_dir ) if out_dir is not None else None
@@ -432,6 +453,7 @@ def convert_root_data_dir(
432
453
overwrite = overwrite ,
433
454
use_beam = use_beam ,
434
455
num_workers = num_workers ,
456
+ convert_fn = convert_fn ,
435
457
)
436
458
437
459
@@ -453,6 +475,7 @@ def convert_dataset_dir(
453
475
use_beam : bool ,
454
476
overwrite : bool = False ,
455
477
num_workers : int = 8 ,
478
+ convert_fn : ConvertFn | None = None ,
456
479
) -> None :
457
480
"""Converts all datasets found in the given dataset dir.
458
481
@@ -467,6 +490,8 @@ def convert_dataset_dir(
467
490
use_beam: whether to use Beam to convert datasets. Useful for big datasets.
468
491
overwrite: whether to overwrite folders in `out_dir` if they already exist.
469
492
num_workers: number of workers to use if `use_beam` is `False`.
493
+ convert_fn: optional function to convert TF examples into whatever is
494
+ desired.
470
495
"""
471
496
dataset_dir = epath .Path (dataset_dir )
472
497
out_path = epath .Path (out_dir ) if out_dir is not None else None
@@ -495,6 +520,7 @@ def convert_dataset_dir(
495
520
overwrite = overwrite ,
496
521
use_beam = use_beam ,
497
522
num_workers = num_workers ,
523
+ convert_fn = convert_fn ,
498
524
)
499
525
500
526
@@ -509,6 +535,7 @@ def convert_dataset(
509
535
overwrite : bool = False ,
510
536
use_beam : bool = False ,
511
537
num_workers : int = 8 ,
538
+ convert_fn : ConvertFn | None = None ,
512
539
) -> None :
513
540
"""Convert a dataset from one file format to another format.
514
541
@@ -534,6 +561,8 @@ def convert_dataset(
534
561
num_workers: number of workers to use when not using Beam. If `use_beam` is
535
562
set, this flag is ignored. If `num_workers=1`, the conversion will be done
536
563
sequentially.
564
+ convert_fn: optional function to convert TF examples into whatever is
565
+ desired.
537
566
"""
538
567
if (
539
568
root_data_dir is None
@@ -588,6 +617,7 @@ def convert_dataset(
588
617
overwrite = overwrite ,
589
618
use_beam = use_beam ,
590
619
num_workers = num_workers ,
620
+ convert_fn = convert_fn ,
591
621
)
592
622
else :
593
623
raise ValueError (
0 commit comments