Skip to content

Commit 1e6682a

Browse files
add label background color parameter
1 parent b208f7f commit 1e6682a

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed
766 Bytes
Loading

test/test_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,30 @@ def test_draw_boxes_with_coloured_label_backgrounds():
164164
)
165165
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
166166
assert_equal(result, expected)
167+
168+
169+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
170+
def test_draw_boxes_with_coloured_label_text_boxes():
171+
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
172+
labels = ["a", "b", "c", "d"]
173+
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
174+
label_colors = ["green", "red", (0, 255, 0), "#FF00FF"]
175+
label_background_colors = ["white", "black", "yellow", "blue"]
176+
result = utils.draw_bounding_boxes(
177+
img,
178+
boxes,
179+
labels=labels,
180+
colors=colors,
181+
fill=True,
182+
label_colors=label_colors,
183+
label_background_colors=label_background_colors,
184+
fill_labels=True
185+
)
186+
path = os.path.join(
187+
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_background_colors.png"
188+
)
189+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
190+
assert_equal(result, expected)
167191

168192

169193
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")

torchvision/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
import numpy as np
1010
import torch
11-
from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont
12-
11+
from PIL import Image, ImageColor, ImageDraw, ImageFont
12+
from PIL import __version__ as PILLOW_VERSION_STRING
1313

1414
__all__ = [
1515
"_Image_fromarray",
@@ -293,6 +293,7 @@ def draw_bounding_boxes(
293293
font: Optional[str] = None,
294294
font_size: Optional[int] = None,
295295
label_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
296+
label_background_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
296297
fill_labels: bool = False,
297298
) -> torch.Tensor:
298299
"""
@@ -320,6 +321,8 @@ def draw_bounding_boxes(
320321
font_size (int): The requested font size in points.
321322
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
322323
`colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
324+
label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the
325+
same colors used for the boxes. Ignored when ``fill_labels`` is False.
323326
fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False.
324327
325328
Returns:
@@ -362,6 +365,11 @@ def draw_bounding_boxes(
362365
else:
363366
label_colors = colors.copy() # type: ignore[assignment]
364367

368+
if fill_labels:
369+
label_background_colors = _parse_colors(label_background_colors, num_objects=num_boxes) if label_background_colors else colors.copy() # type: ignore[assignment]
370+
else:
371+
label_background_colors = colors.copy() # type: ignore[assignment]
372+
365373
if font is None:
366374
if font_size is not None:
367375
warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
@@ -385,7 +393,7 @@ def draw_bounding_boxes(
385393
else:
386394
draw = _ImageDrawTV(img_to_draw)
387395

388-
for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
396+
for bbox, color, label, label_color, label_bg_color in zip(img_boxes, colors, labels, label_colors, label_background_colors): # type: ignore[arg-type]
389397
draw_method = draw.oriented_rectangle if len(bbox) > 4 else draw.rectangle
390398
fill_color = color + (100,) if fill else None
391399
draw_method(bbox, width=width, outline=color, fill=fill_color)
@@ -396,7 +404,7 @@ def draw_bounding_boxes(
396404
if fill_labels:
397405
left, top, right, bottom = draw.textbbox((bbox[0] + margin, bbox[1] + margin), label, font=txt_font)
398406
draw.rectangle(
399-
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=color
407+
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=label_bg_color
400408
)
401409
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type]
402410

0 commit comments

Comments
 (0)