1010import torch
1111from PIL import __version__ as PILLOW_VERSION_STRING , Image , ImageColor , ImageDraw , ImageFont
1212
13-
1413__all__ = [
1514 "_Image_fromarray" ,
1615 "make_grid" ,
@@ -293,6 +292,7 @@ def draw_bounding_boxes(
293292 font : Optional [str ] = None ,
294293 font_size : Optional [int ] = None ,
295294 label_colors : Optional [Union [list [Union [str , tuple [int , int , int ]]], str , tuple [int , int , int ]]] = None ,
295+ label_background_colors : Optional [Union [list [Union [str , tuple [int , int , int ]]], str , tuple [int , int , int ]]] = None ,
296296 fill_labels : bool = False ,
297297) -> torch .Tensor :
298298 """
@@ -320,7 +320,10 @@ def draw_bounding_boxes(
320320 font_size (int): The requested font size in points.
321321 label_colors (color or list of colors, optional): Colors for the label text. See the description of the
322322 `colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
323- fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False.
323+ label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the
324+ same colors used for the boxes. Ignored when ``fill_labels`` is False.
325+ fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter,
326+ or from the ``colors`` parameter if not specified). Default: False.
324327
325328 Returns:
326329 img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
@@ -362,6 +365,11 @@ def draw_bounding_boxes(
362365 else :
363366 label_colors = colors .copy () # type: ignore[assignment]
364367
368+ if fill_labels and label_background_colors :
369+ label_background_colors = _parse_colors (label_background_colors , num_objects = num_boxes )
370+ else :
371+ label_background_colors = colors .copy () # type: ignore[assignment]
372+
365373 if font is None :
366374 if font_size is not None :
367375 warnings .warn ("Argument 'font_size' will be ignored since 'font' is not set." )
@@ -385,7 +393,7 @@ def draw_bounding_boxes(
385393 else :
386394 draw = _ImageDrawTV (img_to_draw )
387395
388- for bbox , color , label , label_color in zip (img_boxes , colors , labels , label_colors ): # type: ignore[arg-type]
396+ for bbox , color , label , label_color , label_bg_color in zip (img_boxes , colors , labels , label_colors , label_background_colors ): # type: ignore[arg-type]
389397 draw_method = draw .oriented_rectangle if len (bbox ) > 4 else draw .rectangle
390398 fill_color = color + (100 ,) if fill else None
391399 draw_method (bbox , width = width , outline = color , fill = fill_color )
@@ -396,7 +404,7 @@ def draw_bounding_boxes(
396404 if fill_labels :
397405 left , top , right , bottom = draw .textbbox ((bbox [0 ] + margin , bbox [1 ] + margin ), label , font = txt_font )
398406 draw .rectangle (
399- (left - box_margin , top - box_margin , right + box_margin , bottom + box_margin ), fill = color
407+ (left - box_margin , top - box_margin , right + box_margin , bottom + box_margin ), fill = label_bg_color
400408 )
401409 draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), label , fill = label_color , font = txt_font ) # type: ignore[arg-type]
402410
@@ -545,7 +553,7 @@ def draw_keypoints(
545553 if visibility .shape != keypoints .shape [:- 1 ]:
546554 raise ValueError (
547555 "keypoints and visibility must have the same dimensionality for num_instances and K. "
548- f"Got { visibility .shape = } and { keypoints .shape = } "
556+ f"Got { visibility .shape = } and { keypoints .shape = } "
549557 )
550558
551559 original_dtype = image .dtype
@@ -746,7 +754,7 @@ def _parse_colors(
746754 f"Number of colors must be equal or larger than the number of objects, but got { len (colors )} < { num_objects } ."
747755 )
748756 elif not isinstance (colors , (tuple , str )):
749- raise ValueError (f"` colors` must be a tuple or a string, or a list thereof, but got { colors } ." )
757+ raise ValueError (f"colors must be a tuple or a string, or a list thereof, but got { colors } ." )
750758 elif isinstance (colors , tuple ) and len (colors ) != 3 :
751759 raise ValueError (f"If passed as tuple, colors should be an RGB triplet, but got { colors } ." )
752760 else : # colors specifies a single color for all objects
0 commit comments