Skip to content

Commit 7c40c08

Browse files
author
The TensorFlow Datasets Authors
committed
Add support for multidimensional arrays in the CroissantBuilder.
PiperOrigin-RevId: 730986386
1 parent b7d6e96 commit 7c40c08

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from tensorflow_datasets.core.features import features_dict
5353
from tensorflow_datasets.core.features import image_feature
5454
from tensorflow_datasets.core.features import sequence_feature
55+
from tensorflow_datasets.core.features import tensor_feature
5556
from tensorflow_datasets.core.features import text_feature
5657
from tensorflow_datasets.core.utils import conversion_utils
5758
from tensorflow_datasets.core.utils import croissant_utils
@@ -75,6 +76,51 @@ def _strip_record_set_prefix(
7576
}
7677

7778

79+
def array_datatype_converter(
80+
feature: type_utils.TfdsDType | feature_lib.FeatureConnector | None,
81+
field: mlc.Field,
82+
int_dtype: type_utils.TfdsDType = np.int64,
83+
float_dtype: type_utils.TfdsDType = np.float32,
84+
):
85+
"""Includes the given feature in a sequence or tensor feature.
86+
87+
Single-dimensional arrays are converted to sequences. Multi-dimensional arrays
88+
with unknown dimensions, or with non-native dtypes are converted to sequences
89+
of sequences. Otherwise, they are converted to tensors.
90+
91+
Args:
92+
feature: The inner feature to include in a sequence or tensor feature.
93+
field: The mlc.Field object.
94+
int_dtype: The dtype to use for TFDS integer features. Defaults to np.int64.
95+
float_dtype: The dtype to use for TFDS float features. Defaults to
96+
np.float32.
97+
98+
Returns:
99+
A sequence or tensor feature including the inner feature.
100+
"""
101+
dtype_mapping = {
102+
int: int_dtype,
103+
float: float_dtype,
104+
bool: np.bool_,
105+
bytes: np.str_,
106+
}
107+
dtype = dtype_mapping.get(field.data_type, None)
108+
if len(field.array_shape_tuple) == 1:
109+
return sequence_feature.Sequence(feature, doc=field.description)
110+
elif (-1 in field.array_shape_tuple) or (
111+
field.data_type not in dtype_mapping
112+
):
113+
for _ in range(len(field.array_shape_tuple)):
114+
feature = sequence_feature.Sequence(feature, doc=field.description)
115+
return feature
116+
else:
117+
return tensor_feature.Tensor(
118+
shape=field.array_shape_tuple,
119+
dtype=dtype,
120+
doc=field.description,
121+
)
122+
123+
78124
def datatype_converter(
79125
field: mlc.Field,
80126
int_dtype: type_utils.TfdsDType = np.int64,
@@ -133,6 +179,16 @@ def datatype_converter(
133179
else:
134180
raise ValueError(f'Unknown data type: {field_data_type}.')
135181

182+
if feature and field.is_array:
183+
feature = array_datatype_converter(
184+
feature=feature,
185+
field=field,
186+
int_dtype=int_dtype,
187+
float_dtype=float_dtype,
188+
)
189+
# If the field is repeated, we return a sequence feature. `field.repeated` is
190+
# deprecated starting from Croissant 1.1, but we still support it for
191+
# backwards compatibility.
136192
if feature and field.repeated:
137193
feature = sequence_feature.Sequence(feature, doc=field.description)
138194
return feature

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,57 @@ def test_complex_datatype_converter(field, feature_type, subfield_types):
161161
)
162162

163163

164+
def test_multidimensional_datatype_converter():
165+
field = mlc.Field(
166+
data_types=mlc.DataType.TEXT,
167+
description="Text feature",
168+
is_array=True,
169+
array_shape="2,2",
170+
)
171+
actual_feature = croissant_builder.datatype_converter(field)
172+
assert isinstance(actual_feature, tensor_feature.Tensor)
173+
assert actual_feature.shape == (2, 2)
174+
assert actual_feature.dtype == np.str_
175+
176+
177+
def test_multidimensional_datatype_converter_image_object():
178+
field = mlc.Field(
179+
data_types=mlc.DataType.IMAGE_OBJECT,
180+
description="Text feature",
181+
is_array=True,
182+
array_shape="2,2",
183+
)
184+
actual_feature = croissant_builder.datatype_converter(field)
185+
assert isinstance(actual_feature, sequence_feature.Sequence)
186+
assert isinstance(actual_feature.feature, sequence_feature.Sequence)
187+
assert isinstance(actual_feature.feature.feature, image_feature.Image)
188+
189+
190+
def test_multidimensional_datatype_converter_plain_list():
191+
field = mlc.Field(
192+
data_types=mlc.DataType.TEXT,
193+
description="Text feature",
194+
is_array=True,
195+
array_shape="-1",
196+
)
197+
actual_feature = croissant_builder.datatype_converter(field)
198+
assert isinstance(actual_feature, sequence_feature.Sequence)
199+
assert isinstance(actual_feature.feature, text_feature.Text)
200+
201+
202+
def test_multidimensional_datatype_converter_unknown_shape():
203+
field = mlc.Field(
204+
data_types=mlc.DataType.TEXT,
205+
description="Text feature",
206+
is_array=True,
207+
array_shape="-1,2",
208+
)
209+
actual_feature = croissant_builder.datatype_converter(field)
210+
assert isinstance(actual_feature, sequence_feature.Sequence)
211+
assert isinstance(actual_feature.feature, sequence_feature.Sequence)
212+
assert isinstance(actual_feature.feature.feature, text_feature.Text)
213+
214+
164215
def test_sequence_feature_datatype_converter():
165216
field = mlc.Field(
166217
data_types=mlc.DataType.TEXT,

0 commit comments

Comments
 (0)