Skip to content

Commit 45b9cfc

Browse files
author
The TensorFlow Datasets Authors
committed
CroissantBuilder: Get the bounding box format from the field's source.
PiperOrigin-RevId: 784167034
1 parent f3cf375 commit 45b9cfc

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from tensorflow_datasets.core import splits as splits_lib
5151
from tensorflow_datasets.core.features import audio_feature
5252
from tensorflow_datasets.core.features import bounding_boxes
53+
from tensorflow_datasets.core.features import bounding_boxes_utils as bb_utils
5354
from tensorflow_datasets.core.features import feature as feature_lib
5455
from tensorflow_datasets.core.features import features_dict
5556
from tensorflow_datasets.core.features import image_feature
@@ -175,8 +176,17 @@ def datatype_converter(
175176
feature = image_feature.Image(doc=field.description)
176177
elif field_data_type == mlc.DataType.BOUNDING_BOX:
177178
# TFDS uses REL_YXYX by default, but Hugging Face doesn't enforce a format.
179+
if bbox_format := field.source.format:
180+
try:
181+
bbox_format = bb_utils.BBoxFormat(bbox_format)
182+
except ValueError as e:
183+
raise ValueError(
184+
f'Unsupported bounding box format: {bbox_format}. Currently'
185+
' supported bounding box formats are: '
186+
f'{[format.value for format in bb_utils.BBoxFormat]}'
187+
) from e
178188
feature = bounding_boxes.BBoxFeature(
179-
doc=field.description, bbox_format=None
189+
doc=field.description, bbox_format=bbox_format
180190
)
181191
elif field_data_type == mlc.DataType.AUDIO_OBJECT:
182192
feature = audio_feature.Audio(

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow_datasets.core.dataset_builders import croissant_builder
2323
from tensorflow_datasets.core.features import audio_feature
2424
from tensorflow_datasets.core.features import bounding_boxes
25+
from tensorflow_datasets.core.features import bounding_boxes_utils as bb_utils
2526
from tensorflow_datasets.core.features import features_dict
2627
from tensorflow_datasets.core.features import image_feature
2728
from tensorflow_datasets.core.features import sequence_feature
@@ -130,6 +131,27 @@ def test_simple_datatype_converter(
130131
assert actual_feature == expected_feature
131132

132133

134+
def test_bbox_datatype_converter():
135+
field = mlc.Field(
136+
data_types=mlc.DataType.BOUNDING_BOX,
137+
description="Bounding box feature",
138+
source=mlc.Source(format="XYWH"),
139+
)
140+
actual_feature = croissant_builder.datatype_converter(field)
141+
assert isinstance(actual_feature, bounding_boxes.BBoxFeature)
142+
assert actual_feature.bbox_format == bb_utils.BBoxFormat.XYWH
143+
144+
145+
def test_bbox_datatype_converter_with_invalid_format():
146+
field = mlc.Field(
147+
data_types=mlc.DataType.BOUNDING_BOX,
148+
description="Bounding box feature",
149+
source=mlc.Source(format="InvalidFormat"),
150+
)
151+
with pytest.raises(ValueError, match="Unsupported bounding box format"):
152+
croissant_builder.datatype_converter(field)
153+
154+
133155
@pytest.mark.parametrize(
134156
["field", "feature_type", "subfield_types"],
135157
[

0 commit comments

Comments
 (0)