@@ -171,9 +171,11 @@ def draw_bounding_boxes(
171171
172172 Args:
173173 image (Tensor): Tensor of shape (C, H, W) and dtype uint8 or float.
174- boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
175- the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
176- `0 <= ymin < ymax < H`.
174+ boxes (Tensor): Tensor of size (N, 4) or (N, 8) containing bounding boxes.
175+ For (N, 4), the format is (xmin, ymin, xmax, ymax) and the boxes are absolute coordinates with respect to the image.
176+ In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`.
177+ For (N, 8), the format is (x1, y1, x3, y3, x2, y2, x4, y4) and the boxes are absolute coordinates with respect to the underlying
178+ object, so no need to verify the latter inequalities.
177179 labels (List[str]): List containing the labels of bounding boxes.
178180 colors (color or list of colors, optional): List containing the colors
179181 of the boxes or single color for all boxes. The color can be represented as
@@ -205,7 +207,7 @@ def draw_bounding_boxes(
205207 raise ValueError ("Pass individual images, not batches" )
206208 elif image .size (0 ) not in {1 , 3 }:
207209 raise ValueError ("Only grayscale and RGB images are supported" )
208- elif ( boxes [:, 0 ] > boxes [:, 2 ]).any () or (boxes [:, 1 ] > boxes [:, 3 ]).any ():
210+ elif boxes . shape [ - 1 ] == 4 and (( boxes [:, 0 ] > boxes [:, 2 ]).any () or (boxes [:, 1 ] > boxes [:, 3 ]).any () ):
209211 raise ValueError (
210212 "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them"
211213 )
@@ -255,9 +257,15 @@ def draw_bounding_boxes(
255257 for bbox , color , label , label_color in zip (img_boxes , colors , labels , label_colors ): # type: ignore[arg-type]
256258 if fill :
257259 fill_color = color + (100 ,)
258- draw .rectangle (bbox , width = width , outline = color , fill = fill_color )
260+ if len (bbox ) == 4 :
261+ draw .rectangle (bbox , width = width , outline = color , fill = fill_color )
262+ else :
263+ draw .polygon (bbox , width = width , outline = color , fill = fill_color )
259264 else :
260- draw .rectangle (bbox , width = width , outline = color )
265+ if len (bbox ) == 4 :
266+ draw .rectangle (bbox , width = width , outline = color )
267+ else :
268+ draw .polygon (bbox , width = width , outline = color )
261269
262270 if label is not None :
263271 box_margin = 1
0 commit comments