Skip to content

Commit 21824ce

Browse files
vfdev-5pmeier
andauthored
Port resize tests to pytest and fix their flakyness (#3907)
Co-authored-by: Philip Meier <[email protected]>
1 parent 9c31d1d commit 21824ce

File tree

1 file changed

+77
-72
lines changed

1 file changed

+77
-72
lines changed

test/test_functional_tensor.py

Lines changed: 77 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -324,76 +324,6 @@ def test_pad(self):
324324

325325
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
326326

327-
def test_resize(self):
328-
script_fn = torch.jit.script(F.resize)
329-
tensor, pil_img = self._create_data(26, 36, device=self.device)
330-
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
331-
332-
for dt in [None, torch.float32, torch.float64, torch.float16]:
333-
334-
if dt == torch.float16 and torch.device(self.device).type == "cpu":
335-
# skip float16 on CPU case
336-
continue
337-
338-
if dt is not None:
339-
# This is a trivial cast to float of uint8 data to test all cases
340-
tensor = tensor.to(dt)
341-
batch_tensors = batch_tensors.to(dt)
342-
343-
for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
344-
for max_size in (None, 33, 40, 1000):
345-
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
346-
continue # unsupported, see assertRaises below
347-
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
348-
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
349-
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
350-
351-
assert_equal(
352-
resized_tensor.size()[1:],
353-
resized_pil_img.size[::-1],
354-
msg="{}, {}".format(size, interpolation),
355-
)
356-
357-
if interpolation not in [NEAREST, ]:
358-
# We can not check values if mode = NEAREST, as results are different
359-
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
360-
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
361-
resized_tensor_f = resized_tensor
362-
# we need to cast to uint8 to compare with PIL image
363-
if resized_tensor_f.dtype == torch.uint8:
364-
resized_tensor_f = resized_tensor_f.to(torch.float)
365-
366-
# Pay attention to high tolerance for MAE
367-
self.approxEqualTensorToPIL(
368-
resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
369-
)
370-
371-
if isinstance(size, int):
372-
script_size = [size, ]
373-
else:
374-
script_size = size
375-
376-
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
377-
max_size=max_size)
378-
assert_equal(resized_tensor, resize_result, msg="{}, {}".format(size, interpolation))
379-
380-
self._test_fn_on_batch(
381-
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
382-
)
383-
384-
# assert changed type warning
385-
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
386-
res1 = F.resize(tensor, size=32, interpolation=2)
387-
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
388-
assert_equal(res1, res2)
389-
390-
for img in (tensor, pil_img):
391-
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
392-
with self.assertRaisesRegex(ValueError, exp_msg):
393-
F.resize(img, size=(32, 34), max_size=35)
394-
with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
395-
F.resize(img, size=32, max_size=32)
396-
397327
def test_resized_crop(self):
398328
# test values of F.resized_crop in several cases:
399329
# 1) resize to the same size, crop to the same size => should be identity
@@ -868,18 +798,93 @@ def test_perspective_interpolation_warning(tester):
868798
tester.assertTrue(res1.equal(res2))
869799

870800

801+
@pytest.mark.parametrize('device', cpu_and_gpu())
802+
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
803+
@pytest.mark.parametrize('size', [32, 26, [32, ], [32, 32], (32, 32), [26, 35]])
804+
@pytest.mark.parametrize('max_size', [None, 34, 40, 1000])
805+
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST])
806+
def test_resize(device, dt, size, max_size, interpolation, tester):
807+
808+
if dt == torch.float16 and device == "cpu":
809+
# skip float16 on CPU case
810+
return
811+
812+
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
813+
return # unsupported
814+
815+
torch.manual_seed(12)
816+
script_fn = torch.jit.script(F.resize)
817+
tensor, pil_img = tester._create_data(26, 36, device=device)
818+
batch_tensors = tester._create_data_batch(16, 18, num_samples=4, device=device)
819+
820+
if dt is not None:
821+
# This is a trivial cast to float of uint8 data to test all cases
822+
tensor = tensor.to(dt)
823+
batch_tensors = batch_tensors.to(dt)
824+
825+
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
826+
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
827+
828+
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
829+
830+
if interpolation not in [NEAREST, ]:
831+
# We can not check values if mode = NEAREST, as results are different
832+
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
833+
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
834+
resized_tensor_f = resized_tensor
835+
# we need to cast to uint8 to compare with PIL image
836+
if resized_tensor_f.dtype == torch.uint8:
837+
resized_tensor_f = resized_tensor_f.to(torch.float)
838+
839+
# Pay attention to high tolerance for MAE
840+
tester.approxEqualTensorToPIL(resized_tensor_f, resized_pil_img, tol=8.0)
841+
842+
if isinstance(size, int):
843+
script_size = [size, ]
844+
else:
845+
script_size = size
846+
847+
resize_result = script_fn(
848+
tensor, size=script_size, interpolation=interpolation, max_size=max_size
849+
)
850+
assert_equal(resized_tensor, resize_result)
851+
852+
tester._test_fn_on_batch(
853+
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
854+
)
855+
856+
857+
@pytest.mark.parametrize('device', cpu_and_gpu())
858+
def test_resize_asserts(device, tester):
859+
860+
tensor, pil_img = tester._create_data(26, 36, device=device)
861+
862+
# assert changed type warning
863+
with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
864+
res1 = F.resize(tensor, size=32, interpolation=2)
865+
866+
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
867+
assert_equal(res1, res2)
868+
869+
for img in (tensor, pil_img):
870+
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
871+
with pytest.raises(ValueError, match=exp_msg):
872+
F.resize(img, size=(32, 34), max_size=35)
873+
with pytest.raises(ValueError, match="max_size = 32 must be strictly greater"):
874+
F.resize(img, size=32, max_size=32)
875+
876+
871877
@pytest.mark.parametrize('device', cpu_and_gpu())
872878
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
873879
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
874880
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
875881
def test_resize_antialias(device, dt, size, interpolation, tester):
876882

877-
torch.manual_seed(12)
878-
879883
if dt == torch.float16 and device == "cpu":
880884
# skip float16 on CPU case
881885
return
882886

887+
torch.manual_seed(12)
883888
script_fn = torch.jit.script(F.resize)
884889
tensor, pil_img = tester._create_data(320, 290, device=device)
885890

0 commit comments

Comments
 (0)