Skip to content

Commit 7b33592

Browse files
author
The TensorFlow Datasets Authors
committed
CroissantBuilder: Add support for fields with subfields in datatype_converter.
PiperOrigin-RevId: 683118823
1 parent e1220a9 commit 7b33592

File tree

3 files changed

+45
-4
lines changed

3 files changed

+45
-4
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161

6262

6363
def datatype_converter(
64-
field,
64+
field: mlc.Field,
6565
int_dtype: Optional[type_utils.TfdsDType] = np.int64,
6666
float_dtype: Optional[type_utils.TfdsDType] = np.float32,
6767
):
@@ -83,8 +83,15 @@ def datatype_converter(
8383
raise NotImplementedError('Not implemented yet.')
8484

8585
field_data_type = field.data_type
86-
8786
if not field_data_type:
87+
# Fields with sub fields are of type None
88+
if field.sub_fields:
89+
return features_dict.FeaturesDict({
90+
subfield.id: datatype_converter(
91+
subfield, int_dtype=int_dtype, float_dtype=float_dtype
92+
)
93+
for subfield in field.sub_fields
94+
})
8895
return None
8996
elif field_data_type == int:
9097
return int_dtype

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from tensorflow_datasets import testing
2121
from tensorflow_datasets.core import file_adapters
2222
from tensorflow_datasets.core.dataset_builders import croissant_builder
23+
from tensorflow_datasets.core.features import features_dict
2324
from tensorflow_datasets.core.features import image_feature
25+
from tensorflow_datasets.core.features import tensor_feature
2426
from tensorflow_datasets.core.features import text_feature
2527
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
2628

@@ -95,28 +97,52 @@ def test_simple_datatype_converter(field, feature_type, int_dtype, float_dtype):
9597

9698

9799
@pytest.mark.parametrize(
98-
["field", "feature_type"],
100+
["field", "feature_type", "subfield_types"],
99101
[
100102
(
101103
mlc.Field(data_types=mlc.DataType.TEXT, description="Text feature"),
102104
text_feature.Text,
105+
None,
103106
),
104107
(
105108
mlc.Field(data_types=mlc.DataType.DATE, description="Date feature"),
106109
text_feature.Text,
110+
None,
107111
),
108112
(
109113
mlc.Field(
110114
data_types=mlc.DataType.IMAGE_OBJECT,
111115
description="Image feature",
112116
),
113117
image_feature.Image,
118+
None,
119+
),
120+
(
121+
mlc.Field(
122+
id="person",
123+
data_types=[],
124+
description="A field with subfields",
125+
sub_fields=[
126+
mlc.Field(id="person/name", data_types=mlc.DataType.TEXT),
127+
mlc.Field(id="person/age", data_types=mlc.DataType.INTEGER),
128+
],
129+
),
130+
features_dict.FeaturesDict,
131+
{
132+
"person/name": text_feature.Text,
133+
"person/age": tensor_feature.Tensor,
134+
},
114135
),
115136
],
116137
)
117-
def test_complex_datatype_converter(field, feature_type):
138+
def test_complex_datatype_converter(field, feature_type, subfield_types):
118139
actual_feature = croissant_builder.datatype_converter(field)
119140
assert isinstance(actual_feature, feature_type)
141+
if subfield_types:
142+
for feature_name in actual_feature.keys():
143+
assert isinstance(
144+
actual_feature[feature_name], subfield_types[feature_name]
145+
)
120146

121147

122148
@pytest.fixture(name="crs_builder")

tensorflow_datasets/core/utils/conversion_utils_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,14 @@ def test_convert_value_raises(value, feature):
140140
),
141141
{'foo': b''},
142142
),
143+
(
144+
{'name': b'Name', 'age': 100},
145+
feature_lib.FeaturesDict({
146+
'name': feature_lib.Text(),
147+
'age': feature_lib.Scalar(dtype=np.int32),
148+
}),
149+
{'name': b'Name', 'age': 100},
150+
),
143151
# nan, but the feature type is not float
144152
(
145153
np.nan,

0 commit comments

Comments
 (0)