Skip to content

Commit 40d6b37

Browse files
Author: The TensorFlow Datasets Authors (committed)
Add support for bounding boxes in croissant_builder.
PiperOrigin-RevId: 686902619
1 parent c388d5f commit 40d6b37

File tree

2 files changed, with 26 additions and 7 deletions.

2 files changed, with 26 additions and 7 deletions.

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from tensorflow_datasets.core import download
4747
from tensorflow_datasets.core import split_builder as split_builder_lib
4848
from tensorflow_datasets.core import splits as splits_lib
49+
from tensorflow_datasets.core.features import bounding_boxes
4950
from tensorflow_datasets.core.features import feature as feature_lib
5051
from tensorflow_datasets.core.features import features_dict
5152
from tensorflow_datasets.core.features import image_feature
@@ -86,12 +87,15 @@ def datatype_converter(
8687
if not field_data_type:
8788
# Fields with sub fields are of type None
8889
if field.sub_fields:
89-
return features_dict.FeaturesDict({
90-
subfield.id: datatype_converter(
91-
subfield, int_dtype=int_dtype, float_dtype=float_dtype
92-
)
93-
for subfield in field.sub_fields
94-
})
90+
return features_dict.FeaturesDict(
91+
{
92+
subfield.id: datatype_converter(
93+
subfield, int_dtype=int_dtype, float_dtype=float_dtype
94+
)
95+
for subfield in field.sub_fields
96+
},
97+
doc=field.description,
98+
)
9599
return None
96100
elif field_data_type == int:
97101
return int_dtype
@@ -106,6 +110,9 @@ def datatype_converter(
106110
return text_feature.Text(doc=field.description)
107111
elif field_data_type == mlc.DataType.IMAGE_OBJECT:
108112
return image_feature.Image(doc=field.description)
113+
elif field_data_type == mlc.DataType.BOUNDING_BOX:
114+
# TFDS uses REL_YXYX by default, but Hugging Face doesn't enforce a format.
115+
return bounding_boxes.BBoxFeature(doc=field.description, bbox_format=None)
109116
else:
110117
raise ValueError(f'Unknown data type: {field_data_type}.')
111118

@@ -225,7 +232,9 @@ def __init__(
225232
@property
226233
def builder_config(self) -> dataset_builder.BuilderConfig:
227234
"""`tfds.core.BuilderConfig` for this builder."""
228-
return self._builder_config # pytype: disable=bad-return-type # always-use-return-annotations
235+
return (
236+
self._builder_config
237+
) # pytype: disable=bad-return-type # always-use-return-annotations
229238

230239
def _info(self) -> dataset_info.DatasetInfo:
231240
return dataset_info.DatasetInfo(

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tensorflow_datasets import testing
2121
from tensorflow_datasets.core import file_adapters
2222
from tensorflow_datasets.core.dataset_builders import croissant_builder
23+
from tensorflow_datasets.core.features import bounding_boxes
2324
from tensorflow_datasets.core.features import features_dict
2425
from tensorflow_datasets.core.features import image_feature
2526
from tensorflow_datasets.core.features import tensor_feature
@@ -117,6 +118,14 @@ def test_simple_datatype_converter(field, feature_type, int_dtype, float_dtype):
117118
image_feature.Image,
118119
None,
119120
),
121+
(
122+
mlc.Field(
123+
data_types=mlc.DataType.BOUNDING_BOX,
124+
description="Bounding box feature",
125+
),
126+
bounding_boxes.BBoxFeature,
127+
None,
128+
),
120129
(
121130
mlc.Field(
122131
id="person",
@@ -138,6 +147,7 @@ def test_simple_datatype_converter(field, feature_type, int_dtype, float_dtype):
138147
def test_complex_datatype_converter(field, feature_type, subfield_types):
139148
actual_feature = croissant_builder.datatype_converter(field)
140149
assert isinstance(actual_feature, feature_type)
150+
assert actual_feature.doc.desc == field.description
141151
if subfield_types:
142152
for feature_name in actual_feature.keys():
143153
assert isinstance(

Comments: 0 commit comments