Skip to content

Commit 9f9e320

Browse files
Add tests for label background drawing
1 parent 2e72bbb commit 9f9e320

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed
680 Bytes
Loading

test/test_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,25 @@ def test_draw_boxes_with_coloured_labels():
131131
assert_equal(result, expected)
132132

133133

134+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
135+
def test_draw_boxes_with_coloured_label_backgrounds():
136+
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
137+
labels = ["a", "b", "c", "d"]
138+
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
139+
label_colors = ["green", "red", (0, 255, 0), "#FF00FF"]
140+
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True, label_colors=label_colors, fill_labels=True)
141+
# utils.save_image(
142+
# result.div(255),
143+
# "/home/antoinesimoulin/vision/test/assets/fakedata/draw_boxes_different_label_fill_colors.png",
144+
# )
145+
146+
path = os.path.join(
147+
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_different_label_fill_colors.png"
148+
)
149+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
150+
assert_equal(result, expected)
151+
152+
134153
@pytest.mark.parametrize("fill", [True, False])
135154
def test_draw_boxes_dtypes(fill):
136155
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)

0 commit comments

Comments
 (0)