37
37
class BBoxFeature (tensor_feature .Tensor ):
38
38
"""`FeatureConnector` for a normalized bounding box.
39
39
40
+ By default, TFDS uses normalized YXYX bbox format. This can be changed by
41
+ passing the `bbox_format` argument, e.g.
42
+ ```
43
+ features=features.FeatureDict({
44
+ 'bbox': tfds.features.BBox(bbox_format=bb_utils.BBoxFormat.XYWH),
45
+ })
46
+ ```
47
+ If you don't know the format of the bbox, you can use `bbox_format=None`. In
48
+ this case, the only check that is done is that 4 floats coordinates are
49
+ provided.
50
+
40
51
Note: If you have multiple bounding boxes, you may want to wrap the feature
41
52
inside a `tfds.features.Sequence`.
42
53
@@ -69,12 +80,17 @@ def __init__(
69
80
self ,
70
81
* ,
71
82
doc : feature_lib .DocArg = None ,
83
+ bbox_format : (
84
+ bb_utils .BBoxFormatType | None
85
+ ) = bb_utils .BBoxFormat .REL_YXYX ,
72
86
):
73
- super (BBoxFeature , self ).__init__ (shape = (4 ,), dtype = np .float32 , doc = doc )
87
+ if isinstance (bbox_format , str ):
88
+ bbox_format = bb_utils .BBoxFormat (bbox_format )
89
+ self .bbox_format = bbox_format
90
+ super ().__init__ (shape = (4 ,), dtype = np .float32 , doc = doc )
74
91
75
92
def encode_example (self , bbox : Union [bb_utils .BBox , np .ndarray ]):
76
93
"""See base class for details."""
77
-
78
94
if isinstance (bbox , np .ndarray ):
79
95
if bbox .shape != (4 ,):
80
96
raise ValueError (
@@ -88,14 +104,22 @@ def encode_example(self, bbox: Union[bb_utils.BBox, np.ndarray]):
88
104
89
105
# Validate the coordinates
90
106
for coordinate in bbox :
91
- if not isinstance (coordinate , (float , np .floating )):
92
- raise ValueError (
93
- 'BBox coordinates should be float. Got {}.' .format (bbox )
94
- )
95
- if not 0.0 <= coordinate <= 1.0 :
96
- raise ValueError (
97
- 'BBox coordinates should be between 0 and 1. Got {}.' .format (bbox )
98
- )
107
+ if (
108
+ self .bbox_format == bb_utils .BBoxFormat .REL_YXYX
109
+ or self .bbox_format == bb_utils .BBoxFormat .REL_XYXY
110
+ ):
111
+ if not isinstance (coordinate , (float , np .floating )):
112
+ raise ValueError (
113
+ 'BBox coordinates should be float. Got {}.' .format (bbox )
114
+ )
115
+ if not 0.0 <= coordinate <= 1.0 :
116
+ raise ValueError (
117
+ 'BBox coordinates should be between 0 and 1. Got {}.' .format (bbox )
118
+ )
119
+ if (
120
+ self .bbox_format == bb_utils .BBoxFormat .YXYX
121
+ or self .bbox_format == bb_utils .BBoxFormat .REL_YXYX
122
+ ):
99
123
if bbox .xmax < bbox .xmin or bbox .ymax < bbox .ymin :
100
124
raise ValueError (
101
125
'BBox coordinates should have min <= max. Got {}.' .format (bbox )
@@ -108,28 +132,51 @@ def encode_example(self, bbox: Union[bb_utils.BBox, np.ndarray]):
108
132
def repr_html (self , ex : np .ndarray ) -> str :
109
133
"""Returns the HTML str representation of an Image with BBoxes."""
110
134
ex = np .expand_dims (ex , axis = 0 ) # Expand single bounding box to batch.
111
- return _repr_html (ex )
135
+ return _repr_html (ex , bbox_format = self . bbox_format )
112
136
113
137
def repr_html_batch (self , ex : np .ndarray ) -> str :
114
138
"""Returns the HTML str representation of an Image with BBoxes (Sequence)."""
115
- return _repr_html (ex )
139
+ return _repr_html (ex , bbox_format = self . bbox_format )
116
140
117
141
@classmethod
118
142
def from_json_content (
119
143
cls , value : Union [Json , feature_pb2 .BoundingBoxFeature ]
120
144
) -> 'BBoxFeature' :
121
- del value # Unused
122
- return cls ()
145
+ if isinstance (value , dict ):
146
+ return cls (** value )
147
+ return cls (
148
+ bbox_format = bb_utils .BBoxFormat (value .bbox_format )
149
+ if value .bbox_format
150
+ else None
151
+ )
123
152
124
- def to_json_content (self ) -> feature_pb2 .BoundingBoxFeature : # pytype: disable=signature-mismatch # overriding-return-type-checks
153
+ def to_json_content (
154
+ self ,
155
+ ) -> (
156
+ feature_pb2 .BoundingBoxFeature
157
+ ): # pytype: disable=signature-mismatch # overriding-return-type-checks
158
+ bbox_format = None
159
+ if self .bbox_format :
160
+ bbox_format = (
161
+ self .bbox_format
162
+ if isinstance (self .bbox_format , str )
163
+ else self .bbox_format .value
164
+ )
125
165
return feature_pb2 .BoundingBoxFeature (
126
166
shape = feature_lib .to_shape_proto (self ._shape ),
127
167
dtype = feature_lib .dtype_to_str (self ._dtype ),
168
+ bbox_format = bbox_format ,
128
169
)
129
170
130
171
131
- def _repr_html (ex : np .ndarray ) -> str :
172
+ def _repr_html (
173
+ ex : np .ndarray , bbox_format : bb_utils .BBoxFormatType | None
174
+ ) -> str :
132
175
"""Returns the HTML str representation of an Image with BBoxes."""
176
+ # If the bbox format is not normalized, we don't draw the bbox on a blank
177
+ # image but we return a string representation of the bbox instead.
178
+ if bbox_format != bb_utils .BBoxFormat .REL_YXYX :
179
+ return repr (ex )
133
180
img = _build_thumbnail_with_bbox (ex )
134
181
img_str = utils .get_base64 (lambda buff : img .save (buff , format = 'PNG' ))
135
182
return f'<img src="data:image/png;base64,{ img_str } " alt="Img" />'
0 commit comments