@@ -123,6 +123,57 @@ def norm_range(t, value_range):
123123 return grid
124124
125125
126+ class ImageDrawTV (ImageDraw .ImageDraw ):
127+ """
128+ A wrapper around PIL.ImageDraw to add functionalities for drawing rotated bounding boxes.
129+ """
130+
131+ def oriented_rectangle (self , xy , fill = None , outline = None , width = 1 ):
132+ self .dashed_line (((xy [0 ], xy [1 ]), (xy [2 ], xy [3 ])), width = width , fill = outline )
133+ for i in range (2 , len (xy ), 2 ):
134+ self .line (
135+ ((xy [i ], xy [i + 1 ]), (xy [(i + 2 ) % len (xy )], xy [(i + 3 ) % len (xy )])),
136+ width = width ,
137+ fill = outline ,
138+ )
139+ self .rectangle (xy , fill = fill , outline = None , width = 0 )
140+
141+ def dashed_line (self , xy , fill = None , width = 0 , joint = None , dash_length = 5 , space_length = 5 ):
142+ # Calculate the total length of the line
143+ total_length = 0
144+ for i in range (1 , len (xy )):
145+ x1 , y1 = xy [i - 1 ]
146+ x2 , y2 = xy [i ]
147+ total_length += ((x2 - x1 ) ** 2 + (y2 - y1 ) ** 2 ) ** 0.5
148+ # Initialize the current position and the current dash
149+ current_position = 0
150+ current_dash = True
151+ # Iterate over the coordinates of the line
152+ for i in range (1 , len (xy )):
153+ x1 , y1 = xy [i - 1 ]
154+ x2 , y2 = xy [i ]
155+ # Calculate the length of this segment
156+ segment_length = ((x2 - x1 ) ** 2 + (y2 - y1 ) ** 2 ) ** 0.5
157+ # While there are still dashes to draw on this segment
158+ while segment_length > 0 :
159+ # Calculate the length of this dash
160+ dash_length_to_draw = min (segment_length , dash_length if current_dash else space_length )
161+ # Calculate the end point of this dash
162+ dx = x2 - x1
163+ dy = y2 - y1
164+ angle = math .atan2 (dy , dx )
165+ end_x = x1 + math .cos (angle ) * dash_length_to_draw
166+ end_y = y1 + math .sin (angle ) * dash_length_to_draw
167+ # If this is a dash, draw it
168+ if current_dash :
169+ self .line ([(x1 , y1 ), (end_x , end_y )], fill , width , joint )
170+ # Update the current position and the current dash
171+ current_position += dash_length_to_draw
172+ segment_length -= dash_length_to_draw
173+ x1 , y1 = end_x , end_y
174+ current_dash = not current_dash
175+
176+
126177@torch .no_grad ()
127178def save_image (
128179 tensor : Union [torch .Tensor , list [torch .Tensor ]],
@@ -250,22 +301,24 @@ def draw_bounding_boxes(
250301 img_boxes = boxes .to (torch .int64 ).tolist ()
251302
252303 if fill :
253- draw = ImageDraw . Draw (img_to_draw , "RGBA" )
304+ draw = ImageDrawTV (img_to_draw , "RGBA" )
254305 else :
255- draw = ImageDraw . Draw (img_to_draw )
306+ draw = ImageDrawTV (img_to_draw )
256307
257308 for bbox , color , label , label_color in zip (img_boxes , colors , labels , label_colors ): # type: ignore[arg-type]
258309 if fill :
259310 fill_color = color + (100 ,)
260311 if len (bbox ) == 4 :
261312 draw .rectangle (bbox , width = width , outline = color , fill = fill_color )
262313 else :
263- draw .polygon (bbox , width = width , outline = color , fill = fill_color )
314+ # Indicate the orientation of the rotated box with dashed line.
315+ draw .oriented_rectangle (bbox , width = width , outline = color , fill = fill_color )
264316 else :
265317 if len (bbox ) == 4 :
266318 draw .rectangle (bbox , width = width , outline = color )
267319 else :
268- draw .polygon (bbox , width = width , outline = color )
320+ # Indicate the orientation of the polygon with dashed line.
321+ draw .oriented_rectangle (bbox , width = width , outline = color )
269322
270323 if label is not None :
271324 box_margin = 1
@@ -433,7 +486,7 @@ def draw_keypoints(
433486
434487 ndarr = image .permute (1 , 2 , 0 ).cpu ().numpy ()
435488 img_to_draw = Image .fromarray (ndarr )
436- draw = ImageDraw . Draw (img_to_draw )
489+ draw = ImageDrawTV (img_to_draw )
437490 img_kpts = keypoints .to (torch .int64 ).tolist ()
438491 img_vis = visibility .cpu ().bool ().tolist ()
439492
0 commit comments