|
22 | 22 | from tensorflow_datasets.core.dataset_builders import croissant_builder
|
23 | 23 | from tensorflow_datasets.core.features import audio_feature
|
24 | 24 | from tensorflow_datasets.core.features import bounding_boxes
|
| 25 | +from tensorflow_datasets.core.features import bounding_boxes_utils as bb_utils |
25 | 26 | from tensorflow_datasets.core.features import features_dict
|
26 | 27 | from tensorflow_datasets.core.features import image_feature
|
27 | 28 | from tensorflow_datasets.core.features import sequence_feature
|
@@ -130,6 +131,27 @@ def test_simple_datatype_converter(
|
130 | 131 | assert actual_feature == expected_feature
|
131 | 132 |
|
132 | 133 |
|
| 134 | +def test_bbox_datatype_converter(): |
| 135 | + field = mlc.Field( |
| 136 | + data_types=mlc.DataType.BOUNDING_BOX, |
| 137 | + description="Bounding box feature", |
| 138 | + source=mlc.Source(format="XYWH"), |
| 139 | + ) |
| 140 | + actual_feature = croissant_builder.datatype_converter(field) |
| 141 | + assert isinstance(actual_feature, bounding_boxes.BBoxFeature) |
| 142 | + assert actual_feature.bbox_format == bb_utils.BBoxFormat.XYWH |
| 143 | + |
| 144 | + |
| 145 | +def test_bbox_datatype_converter_with_invalid_format(): |
| 146 | + field = mlc.Field( |
| 147 | + data_types=mlc.DataType.BOUNDING_BOX, |
| 148 | + description="Bounding box feature", |
| 149 | + source=mlc.Source(format="InvalidFormat"), |
| 150 | + ) |
| 151 | + with pytest.raises(ValueError, match="Unsupported bounding box format"): |
| 152 | + croissant_builder.datatype_converter(field) |
| 153 | + |
| 154 | + |
133 | 155 | @pytest.mark.parametrize(
|
134 | 156 | ["field", "feature_type", "subfield_types"],
|
135 | 157 | [
|
|
0 commit comments