Skip to content

Commit 3996daa

Browse files
Add visualization for rotated boxes
1 parent a7d07dc commit 3996daa

File tree

1 file changed

+58
-5
lines changed

1 file changed

+58
-5
lines changed

torchvision/utils.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,57 @@ def norm_range(t, value_range):
123123
return grid
124124

125125

126+
class ImageDrawTV(ImageDraw.ImageDraw):
127+
"""
128+
A wrapper around PIL.ImageDraw to add functionalities for drawing rotated bounding boxes.
129+
"""
130+
131+
def oriented_rectangle(self, xy, fill=None, outline=None, width=1):
132+
self.dashed_line(((xy[0], xy[1]), (xy[2], xy[3])), width=width, fill=outline)
133+
for i in range(2, len(xy), 2):
134+
self.line(
135+
((xy[i], xy[i + 1]), (xy[(i + 2) % len(xy)], xy[(i + 3) % len(xy)])),
136+
width=width,
137+
fill=outline,
138+
)
139+
self.rectangle(xy, fill=fill, outline=None, width=0)
140+
141+
def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_length=5):
142+
# Calculate the total length of the line
143+
total_length = 0
144+
for i in range(1, len(xy)):
145+
x1, y1 = xy[i - 1]
146+
x2, y2 = xy[i]
147+
total_length += ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
148+
# Initialize the current position and the current dash
149+
current_position = 0
150+
current_dash = True
151+
# Iterate over the coordinates of the line
152+
for i in range(1, len(xy)):
153+
x1, y1 = xy[i - 1]
154+
x2, y2 = xy[i]
155+
# Calculate the length of this segment
156+
segment_length = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
157+
# While there are still dashes to draw on this segment
158+
while segment_length > 0:
159+
# Calculate the length of this dash
160+
dash_length_to_draw = min(segment_length, dash_length if current_dash else space_length)
161+
# Calculate the end point of this dash
162+
dx = x2 - x1
163+
dy = y2 - y1
164+
angle = math.atan2(dy, dx)
165+
end_x = x1 + math.cos(angle) * dash_length_to_draw
166+
end_y = y1 + math.sin(angle) * dash_length_to_draw
167+
# If this is a dash, draw it
168+
if current_dash:
169+
self.line([(x1, y1), (end_x, end_y)], fill, width, joint)
170+
# Update the current position and the current dash
171+
current_position += dash_length_to_draw
172+
segment_length -= dash_length_to_draw
173+
x1, y1 = end_x, end_y
174+
current_dash = not current_dash
175+
176+
126177
@torch.no_grad()
127178
def save_image(
128179
tensor: Union[torch.Tensor, list[torch.Tensor]],
@@ -250,22 +301,24 @@ def draw_bounding_boxes(
250301
img_boxes = boxes.to(torch.int64).tolist()
251302

252303
if fill:
253-
draw = ImageDraw.Draw(img_to_draw, "RGBA")
304+
draw = ImageDrawTV(img_to_draw, "RGBA")
254305
else:
255-
draw = ImageDraw.Draw(img_to_draw)
306+
draw = ImageDrawTV(img_to_draw)
256307

257308
for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
258309
if fill:
259310
fill_color = color + (100,)
260311
if len(bbox) == 4:
261312
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
262313
else:
263-
draw.polygon(bbox, width=width, outline=color, fill=fill_color)
314+
# Indicate the orientation of the rotated box with dashed line.
315+
draw.oriented_rectangle(bbox, width=width, outline=color, fill=fill_color)
264316
else:
265317
if len(bbox) == 4:
266318
draw.rectangle(bbox, width=width, outline=color)
267319
else:
268-
draw.polygon(bbox, width=width, outline=color)
320+
# Indicate the orientation of the polygon with dashed line.
321+
draw.oriented_rectangle(bbox, width=width, outline=color)
269322

270323
if label is not None:
271324
box_margin = 1
@@ -433,7 +486,7 @@ def draw_keypoints(
433486

434487
ndarr = image.permute(1, 2, 0).cpu().numpy()
435488
img_to_draw = Image.fromarray(ndarr)
436-
draw = ImageDraw.Draw(img_to_draw)
489+
draw = ImageDrawTV(img_to_draw)
437490
img_kpts = keypoints.to(torch.int64).tolist()
438491
img_vis = visibility.cpu().bool().tolist()
439492

0 commit comments

Comments
 (0)