Skip to content

Commit 734aed2

Browse files
Add support for rotated boxes in draw_bounding_boxes
1 parent 966da7e commit 734aed2

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

torchvision/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,11 @@ def draw_bounding_boxes(
171171
172172
Args:
173173
image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
174-
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
175-
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
176-
`0 <= ymin < ymax < H`.
174+
boxes (Tensor): Tensor of size (N, 4) or (N, 8) containing bounding boxes.
175+
For (N, 4), the format is (xmin, ymin, xmax, ymax) and the boxes are absolute coordinates with respect to the image.
176+
In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`.
177+
For (N, 8), the format is (x1, y1, x3, y3, x2, y2, x4, y4) and the boxes are absolute coordinates with respect to the underlying
178+
object, so no need to verify the latter inequalities.
177179
labels (List[str]): List containing the labels of bounding boxes.
178180
colors (color or list of colors, optional): List containing the colors
179181
of the boxes or single color for all boxes. The color can be represented as
@@ -205,7 +207,7 @@ def draw_bounding_boxes(
205207
raise ValueError("Pass individual images, not batches")
206208
elif image.size(0) not in {1, 3}:
207209
raise ValueError("Only grayscale and RGB images are supported")
208-
elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any():
210+
elif boxes.shape[-1] == 4 and ((boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any()):
209211
raise ValueError(
210212
"Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
211213
)
@@ -255,9 +257,15 @@ def draw_bounding_boxes(
255257
for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
256258
if fill:
257259
fill_color = color + (100,)
258-
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
260+
if len(bbox) == 4:
261+
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
262+
else:
263+
draw.polygon(bbox, width=width, outline=color, fill=fill_color)
259264
else:
260-
draw.rectangle(bbox, width=width, outline=color)
265+
if len(bbox) == 4:
266+
draw.rectangle(bbox, width=width, outline=color)
267+
else:
268+
draw.polygon(bbox, width=width, outline=color)
261269

262270
if label is not None:
263271
box_margin = 1

0 commit comments

Comments
 (0)