Skip to content

Commit f8d6f8f

Browse files
Add label_colors argument to draw_bounding_boxes (#8578)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 19fef3d commit f8d6f8f

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed
723 Bytes
Loading

test/test_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,21 @@ def test_draw_boxes():
116116
assert_equal(img, img_cp)
117117

118118

119+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
120+
def test_draw_boxes_with_coloured_labels():
121+
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
122+
labels = ["a", "b", "c", "d"]
123+
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
124+
label_colors = ["green", "red", (0, 255, 0), "#FF00FF"]
125+
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True, label_colors=label_colors)
126+
127+
path = os.path.join(
128+
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_colors.png"
129+
)
130+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
131+
assert_equal(result, expected)
132+
133+
119134
@pytest.mark.parametrize("fill", [True, False])
120135
def test_draw_boxes_dtypes(fill):
121136
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)

torchvision/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def draw_bounding_boxes(
161161
width: int = 1,
162162
font: Optional[str] = None,
163163
font_size: Optional[int] = None,
164+
label_colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
164165
) -> torch.Tensor:
165166

166167
"""
@@ -184,9 +185,12 @@ def draw_bounding_boxes(
184185
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
185186
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
186187
font_size (int): The requested font size in points.
188+
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
189+
`colors` argument for details. Defaults to the same colors used for the boxes.
187190
188191
Returns:
189192
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
193+
190194
"""
191195
import torchvision.transforms.v2.functional as F # noqa
192196

@@ -219,6 +223,10 @@ def draw_bounding_boxes(
219223
)
220224

221225
colors = _parse_colors(colors, num_objects=num_boxes)
226+
if label_colors:
227+
label_colors = _parse_colors(label_colors, num_objects=num_boxes)
228+
else:
229+
label_colors = colors.copy()
222230

223231
if font is None:
224232
if font_size is not None:
@@ -243,7 +251,7 @@ def draw_bounding_boxes(
243251
else:
244252
draw = ImageDraw.Draw(img_to_draw)
245253

246-
for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
254+
for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
247255
if fill:
248256
fill_color = color + (100,)
249257
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
@@ -252,7 +260,7 @@ def draw_bounding_boxes(
252260

253261
if label is not None:
254262
margin = width + 1
255-
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
263+
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font)
256264

257265
out = F.pil_to_tensor(img_to_draw)
258266
if original_dtype.is_floating_point:

0 commit comments

Comments
 (0)