Skip to content

Commit 154283b

Browse files
authored
Port test/test_utils.py to pytest (#3917)
1 parent 1b6fe68 commit 154283b

File tree

1 file changed

+124
-115
lines changed

1 file changed

+124
-115
lines changed

test/test_utils.py

Lines changed: 124 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import tempfile
66
import torch
77
import torchvision.utils as utils
8-
import unittest
8+
99
from io import BytesIO
1010
import torchvision.transforms.functional as F
1111
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
@@ -18,122 +18,131 @@
1818
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
1919

2020

21-
class Tester(unittest.TestCase):
22-
23-
def test_make_grid_not_inplace(self):
24-
t = torch.rand(5, 3, 10, 10)
25-
t_clone = t.clone()
26-
27-
utils.make_grid(t, normalize=False)
28-
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
29-
30-
utils.make_grid(t, normalize=True, scale_each=False)
31-
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
32-
33-
utils.make_grid(t, normalize=True, scale_each=True)
34-
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
35-
36-
def test_normalize_in_make_grid(self):
37-
t = torch.rand(5, 3, 10, 10) * 255
38-
norm_max = torch.tensor(1.0)
39-
norm_min = torch.tensor(0.0)
40-
41-
grid = utils.make_grid(t, normalize=True)
42-
grid_max = torch.max(grid)
43-
grid_min = torch.min(grid)
44-
45-
# Rounding the result to one decimal for comparison
46-
n_digits = 1
47-
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
48-
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
49-
50-
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
51-
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
52-
53-
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
54-
def test_save_image(self):
55-
with tempfile.NamedTemporaryFile(suffix='.png') as f:
56-
t = torch.rand(2, 3, 64, 64)
57-
utils.save_image(t, f.name)
58-
self.assertTrue(os.path.exists(f.name), 'The image is not present after save')
59-
60-
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
61-
def test_save_image_single_pixel(self):
62-
with tempfile.NamedTemporaryFile(suffix='.png') as f:
63-
t = torch.rand(1, 3, 1, 1)
64-
utils.save_image(t, f.name)
65-
self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
66-
67-
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
68-
def test_save_image_file_object(self):
69-
with tempfile.NamedTemporaryFile(suffix='.png') as f:
70-
t = torch.rand(2, 3, 64, 64)
71-
utils.save_image(t, f.name)
72-
img_orig = Image.open(f.name)
73-
fp = BytesIO()
74-
utils.save_image(t, fp, format='png')
75-
img_bytes = Image.open(fp)
76-
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
77-
78-
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
79-
def test_save_image_single_pixel_file_object(self):
80-
with tempfile.NamedTemporaryFile(suffix='.png') as f:
81-
t = torch.rand(1, 3, 1, 1)
82-
utils.save_image(t, f.name)
83-
img_orig = Image.open(f.name)
84-
fp = BytesIO()
85-
utils.save_image(t, fp, format='png')
86-
img_bytes = Image.open(fp)
87-
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
88-
89-
def test_draw_boxes(self):
90-
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
91-
img_cp = img.clone()
92-
boxes_cp = boxes.clone()
93-
labels = ["a", "b", "c", "d"]
94-
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
95-
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
96-
97-
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
98-
if not os.path.exists(path):
99-
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
100-
res.save(path)
101-
102-
if PILLOW_VERSION >= (8, 2):
103-
# The reference image is only valid for new PIL versions
104-
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
105-
assert_equal(result, expected)
106-
107-
# Check if modification is not in place
108-
assert_equal(boxes, boxes_cp)
109-
assert_equal(img, img_cp)
110-
111-
def test_draw_boxes_vanilla(self):
112-
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
113-
img_cp = img.clone()
114-
boxes_cp = boxes.clone()
115-
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
116-
117-
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
118-
if not os.path.exists(path):
119-
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
120-
res.save(path)
21+
def test_make_grid_not_inplace():
22+
t = torch.rand(5, 3, 10, 10)
23+
t_clone = t.clone()
24+
25+
utils.make_grid(t, normalize=False)
26+
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
27+
28+
utils.make_grid(t, normalize=True, scale_each=False)
29+
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
30+
31+
utils.make_grid(t, normalize=True, scale_each=True)
32+
assert_equal(t, t_clone, msg='make_grid modified tensor in-place')
33+
34+
35+
def test_normalize_in_make_grid():
36+
t = torch.rand(5, 3, 10, 10) * 255
37+
norm_max = torch.tensor(1.0)
38+
norm_min = torch.tensor(0.0)
39+
40+
grid = utils.make_grid(t, normalize=True)
41+
grid_max = torch.max(grid)
42+
grid_min = torch.min(grid)
43+
44+
# Rounding the result to one decimal for comparison
45+
n_digits = 1
46+
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
47+
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
48+
49+
assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1')
50+
assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0')
51+
52+
53+
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
54+
def test_save_image():
55+
with tempfile.NamedTemporaryFile(suffix='.png') as f:
56+
t = torch.rand(2, 3, 64, 64)
57+
utils.save_image(t, f.name)
58+
assert os.path.exists(f.name), 'The image is not present after save'
12159

60+
61+
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
62+
def test_save_image_single_pixel():
63+
with tempfile.NamedTemporaryFile(suffix='.png') as f:
64+
t = torch.rand(1, 3, 1, 1)
65+
utils.save_image(t, f.name)
66+
assert os.path.exists(f.name), 'The pixel image is not present after save'
67+
68+
69+
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
70+
def test_save_image_file_object():
71+
with tempfile.NamedTemporaryFile(suffix='.png') as f:
72+
t = torch.rand(2, 3, 64, 64)
73+
utils.save_image(t, f.name)
74+
img_orig = Image.open(f.name)
75+
fp = BytesIO()
76+
utils.save_image(t, fp, format='png')
77+
img_bytes = Image.open(fp)
78+
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
79+
80+
81+
@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows')
82+
def test_save_image_single_pixel_file_object():
83+
with tempfile.NamedTemporaryFile(suffix='.png') as f:
84+
t = torch.rand(1, 3, 1, 1)
85+
utils.save_image(t, f.name)
86+
img_orig = Image.open(f.name)
87+
fp = BytesIO()
88+
utils.save_image(t, fp, format='png')
89+
img_bytes = Image.open(fp)
90+
assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object')
91+
92+
93+
def test_draw_boxes():
94+
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
95+
img_cp = img.clone()
96+
boxes_cp = boxes.clone()
97+
labels = ["a", "b", "c", "d"]
98+
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
99+
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
100+
101+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
102+
if not os.path.exists(path):
103+
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
104+
res.save(path)
105+
106+
if PILLOW_VERSION >= (8, 2):
107+
# The reference image is only valid for new PIL versions
122108
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
123109
assert_equal(result, expected)
124-
# Check if modification is not in place
125-
assert_equal(boxes, boxes_cp)
126-
assert_equal(img, img_cp)
127110

128-
def test_draw_invalid_boxes(self):
129-
img_tp = ((1, 1, 1), (1, 2, 3))
130-
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
131-
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
132-
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
133-
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
134-
self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
135-
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
136-
self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
111+
# Check if modification is not in place
112+
assert_equal(boxes, boxes_cp)
113+
assert_equal(img, img_cp)
114+
115+
116+
def test_draw_boxes_vanilla():
117+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
118+
img_cp = img.clone()
119+
boxes_cp = boxes.clone()
120+
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
121+
122+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
123+
if not os.path.exists(path):
124+
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
125+
res.save(path)
126+
127+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
128+
assert_equal(result, expected)
129+
# Check if modification is not in place
130+
assert_equal(boxes, boxes_cp)
131+
assert_equal(img, img_cp)
132+
133+
134+
def test_draw_invalid_boxes():
135+
img_tp = ((1, 1, 1), (1, 2, 3))
136+
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
137+
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
138+
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
139+
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
140+
with pytest.raises(TypeError, match="Tensor expected"):
141+
utils.draw_bounding_boxes(img_tp, boxes)
142+
with pytest.raises(ValueError, match="Tensor uint8 expected"):
143+
utils.draw_bounding_boxes(img_wrong1, boxes)
144+
with pytest.raises(ValueError, match="Pass individual images, not batches"):
145+
utils.draw_bounding_boxes(img_wrong2, boxes)
137146

138147

139148
@pytest.mark.parametrize('colors', [
@@ -218,5 +227,5 @@ def test_draw_segmentation_masks_errors():
218227
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
219228

220229

221-
if __name__ == '__main__':
222-
unittest.main()
230+
if __name__ == "__main__":
231+
pytest.main([__file__])

0 commit comments

Comments
 (0)