Skip to content

Commit 6223d8c

Browse files
Author: The TensorFlow Datasets Authors (committed)
Update the CroissantBuilder to use more precise data types.
PiperOrigin-RevId: 733689187
1 parent 4a4d3c7 commit 6223d8c

File tree

2 files changed: +45 / -27 lines changed

2 files changed

+45
-27
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import json
4141
from typing import Any
4242

43+
from etils import enp
4344
from etils import epath
4445
import numpy as np
4546
from tensorflow_datasets.core import dataset_builder
@@ -79,8 +80,7 @@ def _strip_record_set_prefix(
7980
def array_datatype_converter(
8081
feature: type_utils.TfdsDType | feature_lib.FeatureConnector | None,
8182
field: mlc.Field,
82-
int_dtype: type_utils.TfdsDType = np.int64,
83-
float_dtype: type_utils.TfdsDType = np.float32,
83+
dtype_mapping: Mapping[type_utils.TfdsDType, type_utils.TfdsDType],
8484
):
8585
"""Includes the given feature in a sequence or tensor feature.
8686
@@ -91,32 +91,28 @@ def array_datatype_converter(
9191
Args:
9292
feature: The inner feature to include in a sequence or tensor feature.
9393
field: The mlc.Field object.
94-
int_dtype: The dtype to use for TFDS integer features. Defaults to np.int64.
95-
float_dtype: The dtype to use for TFDS float features. Defaults to
96-
np.float32.
94+
dtype_mapping: A mapping of dtypes to the corresponding dtypes that will be
95+
used in TFDS.
9796
9897
Returns:
9998
A sequence or tensor feature including the inner feature.
10099
"""
101-
dtype_mapping = {
102-
int: int_dtype,
103-
float: float_dtype,
104-
bool: np.bool_,
105-
bytes: np.str_,
106-
}
107-
dtype = dtype_mapping.get(field.data_type, None)
100+
field_dtype = None
101+
if field.data_type in dtype_mapping:
102+
field_dtype = dtype_mapping[field.data_type]
103+
elif enp.lazy.is_np_dtype(field.data_type):
104+
field_dtype = field.data_type
105+
108106
if len(field.array_shape_tuple) == 1:
109107
return sequence_feature.Sequence(feature, doc=field.description)
110-
elif (-1 in field.array_shape_tuple) or (
111-
field.data_type not in dtype_mapping
112-
):
108+
elif (-1 in field.array_shape_tuple) or (field_dtype is None):
113109
for _ in range(len(field.array_shape_tuple)):
114110
feature = sequence_feature.Sequence(feature, doc=field.description)
115111
return feature
116112
else:
117113
return tensor_feature.Tensor(
118114
shape=field.array_shape_tuple,
119-
dtype=dtype,
115+
dtype=field_dtype,
120116
doc=field.description,
121117
)
122118

@@ -142,8 +138,15 @@ def datatype_converter(
142138
"""
143139
if field.is_enumeration:
144140
raise NotImplementedError('Not implemented yet.')
141+
dtype_mapping = {
142+
bool: np.bool_,
143+
bytes: np.str_,
144+
float: float_dtype,
145+
int: int_dtype,
146+
}
145147

146148
field_data_type = field.data_type
149+
147150
if not field_data_type:
148151
# Fields with sub fields are of type None
149152
if field.sub_fields:
@@ -158,14 +161,12 @@ def datatype_converter(
158161
)
159162
else:
160163
feature = None
161-
elif field_data_type == int:
162-
feature = int_dtype
163-
elif field_data_type == float:
164-
feature = float_dtype
165-
elif field_data_type == bool:
166-
feature = np.bool_
167164
elif field_data_type == bytes:
168165
feature = text_feature.Text(doc=field.description)
166+
elif field_data_type in dtype_mapping:
167+
feature = dtype_mapping[field_data_type]
168+
elif enp.lazy.is_np_dtype(field_data_type):
169+
feature = field_data_type
169170
# We return a text feature for mlc.DataType.DATE features.
170171
elif field_data_type == pd.Timestamp:
171172
feature = text_feature.Text(doc=field.description)
@@ -183,8 +184,7 @@ def datatype_converter(
183184
feature = array_datatype_converter(
184185
feature=feature,
185186
field=field,
186-
int_dtype=int_dtype,
187-
float_dtype=float_dtype,
187+
dtype_mapping=dtype_mapping,
188188
)
189189
# If the field is repeated, we return a sequence feature. `field.repeated` is
190190
# deprecated starting from Croissant 1.1, but we still support it for

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151

5252
@pytest.mark.parametrize(
53-
["field", "feature_type", "int_dtype", "float_dtype"],
53+
["field", "expected_feature", "int_dtype", "float_dtype"],
5454
[
5555
(
5656
mlc.Field(
@@ -60,6 +60,14 @@
6060
None,
6161
None,
6262
),
63+
(
64+
mlc.Field(
65+
data_types=mlc.DataType.INT16, description="Int16 feature"
66+
),
67+
np.int16,
68+
None,
69+
None,
70+
),
6371
(
6472
mlc.Field(
6573
data_types=mlc.DataType.INTEGER, description="Integer feature"
@@ -76,6 +84,14 @@
7684
None,
7785
None,
7886
),
87+
(
88+
mlc.Field(
89+
data_types=mlc.DataType.FLOAT16, description="Float16 feature"
90+
),
91+
np.float16,
92+
None,
93+
None,
94+
),
7995
(
8096
mlc.Field(
8197
data_types=mlc.DataType.FLOAT, description="Float feature"
@@ -94,13 +110,15 @@
94110
),
95111
],
96112
)
97-
def test_simple_datatype_converter(field, feature_type, int_dtype, float_dtype):
113+
def test_simple_datatype_converter(
114+
field, expected_feature, int_dtype, float_dtype
115+
):
98116
actual_feature = croissant_builder.datatype_converter(
99117
field,
100118
int_dtype=int_dtype or np.int64,
101119
float_dtype=float_dtype or np.float32,
102120
)
103-
assert actual_feature == feature_type
121+
assert actual_feature == expected_feature
104122

105123

106124
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)