@@ -161,6 +161,7 @@ def draw_bounding_boxes(
161161 width : int = 1 ,
162162 font : Optional [str ] = None ,
163163 font_size : Optional [int ] = None ,
164+ label_colors : Optional [Union [List [Union [str , Tuple [int , int , int ]]], str , Tuple [int , int , int ]]] = None ,
164165) -> torch .Tensor :
165166
166167 """
@@ -184,9 +185,12 @@ def draw_bounding_boxes(
184185 also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
185186 `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
186187 font_size (int): The requested font size in points.
188+ label_colors (color or list of colors, optional): Colors for the label text. See the description of the
189+ `colors` argument for details. Defaults to the same colors used for the boxes.
187190
188191 Returns:
189192 img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
193+
190194 """
191195 import torchvision .transforms .v2 .functional as F # noqa
192196
@@ -219,6 +223,10 @@ def draw_bounding_boxes(
219223 )
220224
221225 colors = _parse_colors (colors , num_objects = num_boxes )
226+ if label_colors :
227+ label_colors = _parse_colors (label_colors , num_objects = num_boxes )
228+ else :
229+ label_colors = colors .copy ()
222230
223231 if font is None :
224232 if font_size is not None :
@@ -243,7 +251,7 @@ def draw_bounding_boxes(
243251 else :
244252 draw = ImageDraw .Draw (img_to_draw )
245253
246- for bbox , color , label in zip (img_boxes , colors , labels ): # type: ignore[arg-type]
254+ for bbox , color , label , label_color in zip (img_boxes , colors , labels , label_colors ): # type: ignore[arg-type]
247255 if fill :
248256 fill_color = color + (100 ,)
249257 draw .rectangle (bbox , width = width , outline = color , fill = fill_color )
@@ -252,7 +260,7 @@ def draw_bounding_boxes(
252260
253261 if label is not None :
254262 margin = width + 1
255- draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), label , fill = color , font = txt_font )
263+ draw .text ((bbox [0 ] + margin , bbox [1 ] + margin ), label , fill = label_color , font = txt_font )
256264
257265 out = F .pil_to_tensor (img_to_draw )
258266 if original_dtype .is_floating_point :
0 commit comments