@@ -116,6 +116,21 @@ def test_draw_boxes():
116116 assert_equal (img , img_cp )
117117
118118
119+ @pytest .mark .skipif (PILLOW_VERSION < (10 , 1 ), reason = "The reference image is only valid for PIL >= 10.1" )
120+ def test_draw_boxes_with_coloured_labels ():
121+ img = torch .full ((3 , 100 , 100 ), 255 , dtype = torch .uint8 )
122+ labels = ["a" , "b" , "c" , "d" ]
123+ colors = ["green" , "#FF00FF" , (0 , 255 , 0 ), "red" ]
124+ label_colors = ["green" , "red" , (0 , 255 , 0 ), "#FF00FF" ]
125+ result = utils .draw_bounding_boxes (img , boxes , labels = labels , colors = colors , fill = True , label_colors = label_colors )
126+
127+ path = os .path .join (
128+ os .path .dirname (os .path .abspath (__file__ )), "assets" , "fakedata" , "draw_boxes_different_label_colors.png"
129+ )
130+ expected = torch .as_tensor (np .array (Image .open (path ))).permute (2 , 0 , 1 )
131+ assert_equal (result , expected )
132+
133+
119134@pytest .mark .parametrize ("fill" , [True , False ])
120135def test_draw_boxes_dtypes (fill ):
121136 img_uint8 = torch .full ((3 , 100 , 100 ), 255 , dtype = torch .uint8 )
0 commit comments