
Commit c8ee73c

marcenacp (The TensorFlow Datasets Authors) authored and committed
Use etree instead of a custom module for parallel_map.
PiperOrigin-RevId: 618901084
1 parent d7c97ea commit c8ee73c

22 files changed: +64 −178 lines
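The change is mechanical: calls to the in-repo helper tree_utils.map_structure now go to the dm-tree package (import tree), and the in-repo utils.tree.parallel_map helper is replaced by etree.parallel_map from etils. A minimal sketch of the two replacement APIs, assuming the dm-tree and etils[etree] packages are installed (the example data is illustrative, not from this commit):

from etils import etree
import tree  # the dm-tree package

nested = {"train": ["a", "b"], "test": ["c"]}

# tree.map_structure applies a function to every leaf, preserving structure.
upper = tree.map_structure(lambda s: s.upper(), nested)
# -> {'train': ['A', 'B'], 'test': ['C']}

# etree.parallel_map does the same but runs the function in a thread pool,
# which is the behavior the removed custom helper provided.
lengths = etree.parallel_map(len, nested)
# -> {'train': [1, 1], 'test': [1]}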

tensorflow_datasets/core/dataset_builder.py

Lines changed: 5 additions & 4 deletions

@@ -54,11 +54,12 @@
 from tensorflow_datasets.core.utils import file_utils
 from tensorflow_datasets.core.utils import gcs_utils
 from tensorflow_datasets.core.utils import read_config as read_config_lib
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 import termcolor
+import tree
+
 
 ListOrTreeOrElem = type_utils.ListOrTreeOrElem
 Tree = type_utils.Tree
@@ -794,7 +795,7 @@ def build_single_data_source(
           f" {' or '.join(args)}?"
       )
 
-    all_ds = tree_utils.map_structure(build_single_data_source, split)
+    all_ds = tree.map_structure(build_single_data_source, split)
     return all_ds
 
   @tfds_logging.as_dataset()
@@ -908,7 +909,7 @@ def as_dataset(
         read_config=read_config,
         as_supervised=as_supervised,
     )
-    all_ds = tree_utils.map_structure(build_single_dataset, split)
+    all_ds = tree.map_structure(build_single_dataset, split)
     return all_ds
 
   def _build_single_dataset(
@@ -961,7 +962,7 @@ def lookup_nest(features: Dict[str, Any]) -> Tuple[Any, ...]:
       Returns:
        A tuple with elements structured according to `supervised_keys`
      """
-      return tree_utils.map_structure(
+      return tree.map_structure(
          lambda key: features[key], self.info.supervised_keys
      )
 
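In dataset_builder.py the split argument may be a single split name or an arbitrary nesting of names, and map_structure builds one dataset (or data source) per leaf. A toy illustration of that dispatch, with build_one as a hypothetical stand-in for the real per-split builder:

import tree  # the dm-tree package

def build_one(split_name: str) -> str:  # hypothetical stand-in
  return f"<dataset for {split_name}>"

split = {"train": "train[:90%]", "eval": ["train[90%:]", "test"]}
all_ds = tree.map_structure(build_one, split)
# all_ds keeps the input nesting: a dict with one dataset per leaf.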

tensorflow_datasets/core/dataset_metadata.py

Lines changed: 7 additions & 11 deletions

@@ -19,8 +19,8 @@
 import functools
 
 from etils import epath
+from etils import etree
 from tensorflow_datasets.core import constants
-from tensorflow_datasets.core import utils
 from tensorflow_datasets.core.utils import resource_utils
 
 
@@ -68,13 +68,11 @@ def valid_tags() -> list[str]:
 
 def valid_tags_with_comments() -> str:
   """Returns valid tags (one per line) with comments."""
-  return "\n".join(
-      [
-          line
-          for line in _get_valid_tags_text().split("\n")
-          if not line.startswith("#")
-      ]
-  )
+  return "\n".join([
+      line
+      for line in _get_valid_tags_text().split("\n")
+      if not line.startswith("#")
+  ])
 
 
 @functools.lru_cache(maxsize=256)
@@ -103,6 +101,4 @@ def _read_files(path: epath.Path) -> dict[str, str]:
   for inode in path.iterdir():
     if inode.name in _METADATA_FILES:
       name2path[inode.name] = path.joinpath(inode.name)
-  return utils.tree.parallel_map(
-      lambda f: f.read_text(encoding="utf-8"), name2path
-  )
+  return etree.parallel_map(lambda f: f.read_text(encoding="utf-8"), name2path)
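After this change, _read_files fans the read_text calls out over a thread pool via etree.parallel_map and returns a dict with the same keys. A hypothetical illustration (the paths below are made up):

from etils import epath, etree

paths = {
    "README.md": epath.Path("/tmp/my_dataset/README.md"),
    "CITATIONS.bib": epath.Path("/tmp/my_dataset/CITATIONS.bib"),
}
# Reads both files concurrently; the result maps each key to the file text.
contents = etree.parallel_map(lambda f: f.read_text(encoding="utf-8"), paths)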

tensorflow_datasets/core/dataset_utils.py

Lines changed: 4 additions & 4 deletions

@@ -27,9 +27,9 @@
 import numpy as np
 from tensorflow_datasets.core import logging as tfds_logging
 from tensorflow_datasets.core import utils
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
+import tree
 
 Tree = type_utils.Tree
 Tensor = type_utils.Tensor
@@ -80,7 +80,7 @@ def element_spec(self) -> Tree[enp.ArraySpec]:
 
 def _eager_dataset_iterator(ds: tf.data.Dataset) -> Iterator[NumpyElem]:
   for elem in ds:
-    yield tree_utils.map_structure(_elem_to_numpy_eager, elem)
+    yield tree.map_structure(_elem_to_numpy_eager, elem)
 
 
 def _graph_dataset_iterator(ds_iter, graph: tf.Graph) -> Iterator[NumpyElem]:
@@ -105,7 +105,7 @@ def _assert_ds_types(nested_ds: Tree[TensorflowElem]) -> None:
         isinstance(el, (tf.Tensor, tf.RaggedTensor))
         or isinstance(el, tf.data.Dataset)
     ):
-      nested_types = tree_utils.map_structure(type, nested_ds)
+      nested_types = tree.map_structure(type, nested_ds)
       raise TypeError(
           'Arguments to as_numpy must be tf.Tensors or tf.data.Datasets. '
          f'Got: {nested_types}.'
@@ -197,7 +197,7 @@ def as_numpy(dataset: Tree[TensorflowElem]) -> Tree[NumpyElem]:
   """
   _assert_ds_types(dataset)
   if tf.executing_eagerly():
-    return tree_utils.map_structure(_elem_to_numpy_eager, dataset)
+    return tree.map_structure(_elem_to_numpy_eager, dataset)
   else:
     return _nested_to_numpy_graph(dataset)
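The eager branch of as_numpy is the same leaf-wise mapping: dm-tree walks any nesting of dicts, lists and tuples and converts each leaf. A simplified sketch of that pattern (the real code goes through _elem_to_numpy_eager, which also handles RaggedTensors and nested datasets):

import tensorflow as tf
import tree  # the dm-tree package

batch = {"image": tf.zeros((2, 3)), "labels": (tf.constant(1), tf.constant(2))}
# Convert every tf.Tensor leaf to a numpy value, keeping the structure.
as_np = tree.map_structure(lambda t: t.numpy(), batch)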

tensorflow_datasets/core/decode/base.py

Lines changed: 3 additions & 3 deletions

@@ -18,8 +18,8 @@
 import abc
 import functools
 
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
+import tree
 
 
 class Decoder(abc.ABC):
@@ -69,7 +69,7 @@ def setup(self, *, feature):
   def dtype(self):
     """Returns the `dtype` after decoding."""
     tensor_info = self.feature.get_tensor_info()
-    return tree_utils.map_structure(lambda t: t.dtype, tensor_info)
+    return tree.map_structure(lambda t: t.dtype, tensor_info)
 
   @abc.abstractmethod
   def decode_example(self, serialized_example):
@@ -135,7 +135,7 @@ class SkipDecoding(Decoder):
   @property
   def dtype(self):
     tensor_info = self.feature.get_serialized_info()
-    return tree_utils.map_structure(lambda t: t.dtype, tensor_info)
+    return tree.map_structure(lambda t: t.dtype, tensor_info)
 
   def decode_example(self, serialized_example):
     """Forward the serialized feature field."""

tensorflow_datasets/core/download/download_manager.py

Lines changed: 3 additions & 7 deletions

@@ -35,8 +35,8 @@
 from tensorflow_datasets.core.download import resource as resource_lib
 from tensorflow_datasets.core.download import util
 from tensorflow_datasets.core.utils import shard_utils
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils import type_utils
+import tree
 
 
 # pylint: disable=logging-fstring-interpolation
@@ -825,10 +825,6 @@ def _read_url_info(url_path: epath.PathLike) -> checksums.UrlInfo:
 
 def _map_promise(map_fn, all_inputs):
   """Map the function into each element and resolve the promise."""
-  all_promises = tree_utils.map_structure(
-      map_fn, all_inputs
-  )  # Apply the function
-  res = tree_utils.map_structure(
-      lambda p: p.get(), all_promises
-  )  # Wait promises
+  all_promises = tree.map_structure(map_fn, all_inputs)  # Apply the function
+  res = tree.map_structure(lambda p: p.get(), all_promises)  # Wait promises
   return res
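_map_promise is a two-pass walk over the same structure: the first map_structure starts all the work, the second blocks on each result, so the nesting of the inputs is preserved in the outputs. A simplified model using the promise library (fetch is a hypothetical stand-in for the real per-resource download function):

import promise
import tree  # the dm-tree package

def fetch(url):  # hypothetical stand-in; resolves immediately here
  return promise.Promise.resolve(f"contents of {url}")

urls = {"train": ["http://a", "http://b"], "test": "http://c"}
all_promises = tree.map_structure(fetch, urls)  # apply the function
results = tree.map_structure(lambda p: p.get(), all_promises)  # wait promises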

tensorflow_datasets/core/download/download_manager_test.py

Lines changed: 2 additions & 4 deletions

@@ -25,16 +25,14 @@
 from absl.testing import parameterized
 from etils import epath
 import promise
-
 import tensorflow as tf
-
 from tensorflow_datasets import testing
 from tensorflow_datasets.core.download import checksums as checksums_lib
 from tensorflow_datasets.core.download import download_manager as dm
 from tensorflow_datasets.core.download import downloader
 from tensorflow_datasets.core.download import extractor
 from tensorflow_datasets.core.download import resource as resource_lib
-from tensorflow_datasets.core.utils import tree_utils
+import tree
 
 ZIP = resource_lib.ExtractMethod.ZIP
 TAR = resource_lib.ExtractMethod.TAR
@@ -46,7 +44,7 @@ def _sha256(str_):
 
 
 def _as_path(nested_paths):
-  return tree_utils.map_structure(epath.Path, nested_paths)
+  return tree.map_structure(epath.Path, nested_paths)
 
 
 def _info_path(path):

tensorflow_datasets/core/features/dataset_feature.py

Lines changed: 5 additions & 4 deletions

@@ -14,6 +14,7 @@
 # limitations under the License.
 
 """Dataset feature for nested datasets."""
+
 from __future__ import annotations
 
 import dataclasses
@@ -26,9 +27,9 @@
 from tensorflow_datasets.core.features import tensor_feature
 from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.utils import py_utils
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
+import tree
 
 
 @dataclasses.dataclass(frozen=True)
@@ -120,7 +121,7 @@ def get_tensor_info(self):
     """Shape of one element of the dataset."""
     # Add the dataset level
     tensor_info = self._feature.get_tensor_info()
-    return tree_utils.map_structure(_add_dataset_lvl, tensor_info)
+    return tree.map_structure(_add_dataset_lvl, tensor_info)
 
   def get_tensor_spec(self) -> tf.data.DatasetSpec:
     return tf.data.DatasetSpec(element_spec=self._feature.get_tensor_spec())
@@ -129,7 +130,7 @@ def get_tensor_spec(self) -> tf.data.DatasetSpec:
   def get_serialized_info(self):
     # Add the dataset level and the number of elements in the dataset
     tensor_info = super().get_serialized_info()
-    return tree_utils.map_structure(_add_dataset_lvl, tensor_info)
+    return tree.map_structure(_add_dataset_lvl, tensor_info)
 
   def encode_example(
       self,
@@ -146,7 +147,7 @@ def encode_example(
 
     # Empty datasets return empty arrays
     if not ds_elements:
-      return tree_utils.map_structure(
+      return tree.map_structure(
          sequence_feature.build_empty_np, self.get_serialized_info()
      )
 

tensorflow_datasets/core/features/feature.py

Lines changed: 5 additions & 7 deletions

@@ -38,10 +38,10 @@
 from tensorflow_datasets.core.utils import np_utils
 from tensorflow_datasets.core.utils import py_utils
 from tensorflow_datasets.core.utils import tf_utils
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 from tensorflow_datasets.core.utils.lazy_imports_utils import tf_agents
+import tree
 
 from google.protobuf import descriptor
 from google.protobuf import json_format
@@ -364,14 +364,14 @@ def get_tensor_spec(self) -> TreeDict[tf.TensorSpec]:
     of the dataset. For example, currently this method does not support
     RaggedTensorSpec.
     """
-    return tree_utils.map_structure(
+    return tree.map_structure(
        lambda ti: ti.to_tensor_spec(), self.get_tensor_info()
    )
 
   @functools.cached_property
   def shape(self):
     """Return the shape (or dict of shape) of this FeatureConnector."""
-    return tree_utils.map_structure(lambda t: t.shape, self.get_tensor_info())
+    return tree.map_structure(lambda t: t.shape, self.get_tensor_info())
 
   @functools.cached_property
   def dtype(self) -> TreeDict[tf.dtypes.DType]:
@@ -381,9 +381,7 @@ def dtype(self) -> TreeDict[tf.dtypes.DType]:
 
   @functools.cached_property
   def np_dtype(self) -> TreeDict[np.dtype]:
-    return tree_utils.map_structure(
-        lambda t: t.np_dtype, self.get_tensor_info()
-    )
+    return tree.map_structure(lambda t: t.np_dtype, self.get_tensor_info())
 
   # For backwards compatibility: now it is named np_dtype.
   @functools.cached_property
@@ -397,7 +395,7 @@ def convert_to_tensorflow(value):
        return tf.dtypes.as_dtype(value)
      return value.tf_dtype
 
-    return tree_utils.map_structure(convert_to_tensorflow, self.np_dtype)
+    return tree.map_structure(convert_to_tensorflow, self.np_dtype)
 
   @classmethod
   def cls_from_name(cls, python_class_name: str) -> Type['FeatureConnector']:

tensorflow_datasets/core/features/sequence_feature.py

Lines changed: 4 additions & 6 deletions

@@ -27,8 +27,8 @@
 from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.proto import feature_pb2
 from tensorflow_datasets.core.utils import py_utils
-from tensorflow_datasets.core.utils import tree_utils
 from tensorflow_datasets.core.utils import type_utils
+import tree
 
 Json = type_utils.Json
 
@@ -121,14 +121,14 @@ def get_tensor_info(self):
     """See base class for details."""
     # Add the additional length dimension to every shape
     tensor_info = self._feature.get_tensor_info()
-    return tree_utils.map_structure(self._add_length_dim, tensor_info)
+    return tree.map_structure(self._add_length_dim, tensor_info)
 
   @py_utils.memoize()
   def get_serialized_info(self):
     """See base class for details."""
     # Add the additional length dimension to every serialized features
     tensor_info = self._feature.get_serialized_info()
-    return tree_utils.map_structure(self._add_length_dim, tensor_info)
+    return tree.map_structure(self._add_length_dim, tensor_info)
 
   def encode_example(self, example_dict):
     # Convert nested dict[list] into list[nested dict]
@@ -143,9 +143,7 @@ def encode_example(self, example_dict):
 
     # Empty sequences return empty arrays
     if not sequence_elements:
-      return tree_utils.map_structure(
-          build_empty_np, self.get_serialized_info()
-      )
+      return tree.map_structure(build_empty_np, self.get_serialized_info())
 
     # Encode each individual elements
     sequence_elements = [
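Sequence features use the same call to prepend a length dimension to every nested shape. A simplified illustration of that pattern, with a plain dataclass standing in for the TFDS TensorInfo object (hypothetical, for illustration only):

import dataclasses
import tree  # the dm-tree package

@dataclasses.dataclass
class Info:  # stand-in for tfds TensorInfo; dm-tree treats it as a leaf
  shape: tuple
  dtype: str

def add_length_dim(info: Info) -> Info:
  return Info(shape=(None,) + info.shape, dtype=info.dtype)

tensor_info = {"tokens": Info((128,), "int64"), "mask": Info((128,), "bool")}
with_length = tree.map_structure(add_length_dim, tensor_info)
# Every leaf now carries shape (None, 128).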

tensorflow_datasets/core/folder_dataset/compute_split_utils.py

Lines changed: 3 additions & 3 deletions

@@ -25,11 +25,11 @@
 from typing import Optional, Type, Union, cast
 
 from etils import epath
+from etils import etree
 from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import splits as split_lib
-from tensorflow_datasets.core import utils
 from tensorflow_datasets.core.proto import dataset_info_pb2
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source
@@ -230,14 +230,14 @@ def _compute_split_statistics(
   # Compute all shard info in parallel
   split_to_shard_infos = cast(
      Mapping[str, Sequence[_ShardInfo]],
-      utils.tree.parallel_map(
+      etree.parallel_map(
          functools.partial(
              _process_shard,
              data_dir=filename_template.data_dir,
              adapter=adapter,
          ),
          split_files,
-          report_progress=True,
+          progress_bar=True,
      ),
  )
  # Create the SplitInfo for all splits
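Note the keyword rename at this call site: the removed helper took report_progress=True, while etree.parallel_map spells the same option progress_bar=True and shows a progress bar while mapping. A minimal sketch of the new call shape (the data is illustrative):

from etils import etree

# Uppercase every leaf in parallel while displaying a progress bar.
etree.parallel_map(str.upper, {"a": ["x", "y"]}, progress_bar=True)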
