Skip to content

Commit baaa7af

Browse files
author
The TensorFlow Datasets Authors
committed
Allow for bboxes in other formats.
PiperOrigin-RevId: 686530258
1 parent 1538910 commit baaa7af

File tree

6 files changed

+120
-29
lines changed

6 files changed

+120
-29
lines changed

tensorflow_datasets/core/features/bounding_boxes.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@
3737
class BBoxFeature(tensor_feature.Tensor):
3838
"""`FeatureConnector` for a normalized bounding box.
3939
40+
By default, TFDS uses normalized YXYX bbox format. This can be changed by
41+
passing the `bbox_format` argument, e.g.
42+
```
43+
features=features.FeatureDict({
44+
'bbox': tfds.features.BBox(bbox_format=bb_utils.BBoxFormat.XYWH),
45+
})
46+
```
47+
If you don't know the format of the bbox, you can use `bbox_format=None`. In
48+
this case, the only check that is done is that 4 floats coordinates are
49+
provided.
50+
4051
Note: If you have multiple bounding boxes, you may want to wrap the feature
4152
inside a `tfds.features.Sequence`.
4253
@@ -69,12 +80,17 @@ def __init__(
6980
self,
7081
*,
7182
doc: feature_lib.DocArg = None,
83+
bbox_format: (
84+
bb_utils.BBoxFormatType | None
85+
) = bb_utils.BBoxFormat.REL_YXYX,
7286
):
73-
super(BBoxFeature, self).__init__(shape=(4,), dtype=np.float32, doc=doc)
87+
if isinstance(bbox_format, str):
88+
bbox_format = bb_utils.BBoxFormat(bbox_format)
89+
self.bbox_format = bbox_format
90+
super().__init__(shape=(4,), dtype=np.float32, doc=doc)
7491

7592
def encode_example(self, bbox: Union[bb_utils.BBox, np.ndarray]):
7693
"""See base class for details."""
77-
7894
if isinstance(bbox, np.ndarray):
7995
if bbox.shape != (4,):
8096
raise ValueError(
@@ -88,14 +104,22 @@ def encode_example(self, bbox: Union[bb_utils.BBox, np.ndarray]):
88104

89105
# Validate the coordinates
90106
for coordinate in bbox:
91-
if not isinstance(coordinate, (float, np.floating)):
92-
raise ValueError(
93-
'BBox coordinates should be float. Got {}.'.format(bbox)
94-
)
95-
if not 0.0 <= coordinate <= 1.0:
96-
raise ValueError(
97-
'BBox coordinates should be between 0 and 1. Got {}.'.format(bbox)
98-
)
107+
if (
108+
self.bbox_format == bb_utils.BBoxFormat.REL_YXYX
109+
or self.bbox_format == bb_utils.BBoxFormat.REL_XYXY
110+
):
111+
if not isinstance(coordinate, (float, np.floating)):
112+
raise ValueError(
113+
'BBox coordinates should be float. Got {}.'.format(bbox)
114+
)
115+
if not 0.0 <= coordinate <= 1.0:
116+
raise ValueError(
117+
'BBox coordinates should be between 0 and 1. Got {}.'.format(bbox)
118+
)
119+
if (
120+
self.bbox_format == bb_utils.BBoxFormat.YXYX
121+
or self.bbox_format == bb_utils.BBoxFormat.REL_YXYX
122+
):
99123
if bbox.xmax < bbox.xmin or bbox.ymax < bbox.ymin:
100124
raise ValueError(
101125
'BBox coordinates should have min <= max. Got {}.'.format(bbox)
@@ -108,28 +132,51 @@ def encode_example(self, bbox: Union[bb_utils.BBox, np.ndarray]):
108132
def repr_html(self, ex: np.ndarray) -> str:
109133
"""Returns the HTML str representation of an Image with BBoxes."""
110134
ex = np.expand_dims(ex, axis=0) # Expand single bounding box to batch.
111-
return _repr_html(ex)
135+
return _repr_html(ex, bbox_format=self.bbox_format)
112136

113137
def repr_html_batch(self, ex: np.ndarray) -> str:
114138
"""Returns the HTML str representation of an Image with BBoxes (Sequence)."""
115-
return _repr_html(ex)
139+
return _repr_html(ex, bbox_format=self.bbox_format)
116140

117141
@classmethod
118142
def from_json_content(
119143
cls, value: Union[Json, feature_pb2.BoundingBoxFeature]
120144
) -> 'BBoxFeature':
121-
del value # Unused
122-
return cls()
145+
if isinstance(value, dict):
146+
return cls(**value)
147+
return cls(
148+
bbox_format=bb_utils.BBoxFormat(value.bbox_format)
149+
if value.bbox_format
150+
else None
151+
)
123152

124-
def to_json_content(self) -> feature_pb2.BoundingBoxFeature: # pytype: disable=signature-mismatch # overriding-return-type-checks
153+
def to_json_content(
154+
self,
155+
) -> (
156+
feature_pb2.BoundingBoxFeature
157+
): # pytype: disable=signature-mismatch # overriding-return-type-checks
158+
bbox_format = None
159+
if self.bbox_format:
160+
bbox_format = (
161+
self.bbox_format
162+
if isinstance(self.bbox_format, str)
163+
else self.bbox_format.value
164+
)
125165
return feature_pb2.BoundingBoxFeature(
126166
shape=feature_lib.to_shape_proto(self._shape),
127167
dtype=feature_lib.dtype_to_str(self._dtype),
168+
bbox_format=bbox_format,
128169
)
129170

130171

131-
def _repr_html(ex: np.ndarray) -> str:
172+
def _repr_html(
173+
ex: np.ndarray, bbox_format: bb_utils.BBoxFormatType | None
174+
) -> str:
132175
"""Returns the HTML str representation of an Image with BBoxes."""
176+
# If the bbox format is not normalized, we don't draw the bbox on a blank
177+
# image but we return a string representation of the bbox instead.
178+
if bbox_format != bb_utils.BBoxFormat.REL_YXYX:
179+
return repr(ex)
133180
img = _build_thumbnail_with_bbox(ex)
134181
img_str = utils.get_base64(lambda buff: img.save(buff, format='PNG'))
135182
return f'<img src="data:image/png;base64,{img_str}" alt="Img" />'

tensorflow_datasets/core/features/bounding_boxes_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ def test_feature(self):
5151
],
5252
)
5353

54+
def test_unspecified_bbox_format_feature(self):
55+
self.assertFeature(
56+
feature=features.BBoxFeature(bbox_format=None),
57+
shape=(4,),
58+
dtype=np.float32,
59+
tests=[
60+
testing.FeatureExpectationItem(
61+
# 1D numpy array, float32 unnormalized bbox
62+
value=np.array(
63+
[200.46, 199.84, 77.71, 70.88], dtype=np.float32
64+
),
65+
expected=np.array(
66+
[200.46, 199.84, 77.71, 70.88], dtype=np.float32
67+
),
68+
),
69+
],
70+
)
71+
5472

5573
if __name__ == '__main__':
5674
testing.test_main()

tensorflow_datasets/core/features/bounding_boxes_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class BBoxFormat(enum.Enum):
4646
XYXY = 'XYXY'
4747
YXYX = 'YXYX'
4848
XYWH = 'XYWH'
49+
REL_XYXY = 'REL_XYXY'
50+
REL_YXYX = 'REL_YXYX'
4951

5052

5153
BBoxFormatType = Union[BBoxFormat, str]
@@ -77,7 +79,6 @@ def convert_coordinates_to_bbox(
7779
"""
7880
if len(coordinates) != 4:
7981
raise ValueError(f'Expected 4 coordinates, got {coordinates}.')
80-
8182
coordinates = coordinates.astype(np.float64)
8283

8384
try:
@@ -88,15 +89,24 @@ def convert_coordinates_to_bbox(
8889
f'Unsupported bbox format: {format}. Currently supported bounding box'
8990
f' formats are: {[format.value for format in BBoxFormat]}'
9091
) from e
91-
92-
if input_format == BBoxFormat.YXYX:
92+
if (
93+
input_format == BBoxFormat.REL_XYXY or input_format == BBoxFormat.REL_YXYX
94+
) and normalize:
95+
raise ValueError(
96+
'If the input format is normalized, then normalize should be False.'
97+
)
98+
if normalize and img_shape is None:
99+
raise ValueError(
100+
'If normalize is True, img_shape must be provided, but got None.'
101+
)
102+
if input_format == BBoxFormat.YXYX or input_format == BBoxFormat.REL_YXYX:
93103
bbox = BBox(
94104
ymin=coordinates[0],
95105
xmin=coordinates[1],
96106
ymax=coordinates[2],
97107
xmax=coordinates[3],
98108
)
99-
elif input_format == BBoxFormat.XYXY:
109+
elif input_format == BBoxFormat.XYXY or input_format == BBoxFormat.REL_XYXY:
100110
bbox = BBox(
101111
ymin=coordinates[1],
102112
xmin=coordinates[0],

tensorflow_datasets/core/features/bounding_boxes_utils_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@ def test_convert_to_bbox(
6767
[
6868
([1, 2, 3], None, 'Expected 4'),
6969
(TEST_INPUT_LIST, 'NonExistentConverter', 'Unsupported bbox format'),
70+
(
71+
TEST_INPUT_LIST,
72+
bb_utils.BBoxFormat.XYXY,
73+
'If normalize is True, img_shape must be provided, but got None.',
74+
),
75+
(
76+
TEST_INPUT_LIST,
77+
bb_utils.BBoxFormat.REL_XYXY,
78+
(
79+
'If the input format is normalized, then normalize should be'
80+
' False.'
81+
),
82+
),
7083
],
7184
)
7285
def test_convert_coordinates_to_bbox_valueerror(

tensorflow_datasets/core/proto/feature.proto

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,12 @@ message AudioFeature {
105105

106106
// A bounding box around an object in an image. Typically, bounding boxes are
107107
// tensors of type `tf.float32` and shape `[4,]` and contain the normalized
108-
// coordinates of the bounding box `[ymin, xmin, ymax, xmax]`.
108+
// coordinates of the bounding box `[ymin, xmin, ymax, xmax]`. Different formats
109+
// are supported through the `bbox_format` field.
109110
message BoundingBoxFeature {
110111
Shape shape = 1;
111112
string dtype = 2;
113+
string bbox_format = 3;
112114
}
113115

114116
message TextFeature {}

tensorflow_datasets/core/proto/feature_generated_pb2.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@
7272
b' \x01(\t\x12\x13\n\x0bsample_rate\x18\x04'
7373
b' \x01(\x03\x12\x10\n\x08\x65ncoding\x18\x05'
7474
b' \x01(\t\x12\x13\n\x0blazy_decode\x18\x06'
75-
b' \x01(\x08"N\n\x12\x42oundingBoxFeature\x12)\n\x05shape\x18\x01'
75+
b' \x01(\x08"c\n\x12\x42oundingBoxFeature\x12)\n\x05shape\x18\x01'
7676
b' \x01(\x0b\x32\x1a.tensorflow_datasets.Shape\x12\r\n\x05\x64type\x18\x02'
77+
b' \x01(\t\x12\x13\n\x0b\x62\x62ox_format\x18\x03'
7778
b' \x01(\t"\r\n\x0bTextFeature"O\n\x12TranslationFeature\x12\x11\n\tlanguages\x18\x01'
7879
b' \x03(\t\x12&\n\x1evariable_languages_per_example\x18\x02'
7980
b' \x01(\x08"I\n\x08Sequence\x12-\n\x07\x66\x65\x61ture\x18\x01'
@@ -110,11 +111,11 @@
110111
_AUDIOFEATURE._serialized_start = 1500
111112
_AUDIOFEATURE._serialized_end = 1653
112113
_BOUNDINGBOXFEATURE._serialized_start = 1655
113-
_BOUNDINGBOXFEATURE._serialized_end = 1733
114-
_TEXTFEATURE._serialized_start = 1735
115-
_TEXTFEATURE._serialized_end = 1748
116-
_TRANSLATIONFEATURE._serialized_start = 1750
117-
_TRANSLATIONFEATURE._serialized_end = 1829
118-
_SEQUENCE._serialized_start = 1831
119-
_SEQUENCE._serialized_end = 1904
114+
_BOUNDINGBOXFEATURE._serialized_end = 1754
115+
_TEXTFEATURE._serialized_start = 1756
116+
_TEXTFEATURE._serialized_end = 1769
117+
_TRANSLATIONFEATURE._serialized_start = 1771
118+
_TRANSLATIONFEATURE._serialized_end = 1850
119+
_SEQUENCE._serialized_start = 1852
120+
_SEQUENCE._serialized_end = 1925
120121
# @@protoc_insertion_point(module_scope)

0 commit comments

Comments
 (0)