
Commit b96d381

NicolasHug and pmeier authored
Use torch.testing.assert_close in test_functional_tensor (#3876)
Co-authored-by: Philip Meier <[email protected]>
1 parent 963d432 commit b96d381
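
The change replaces unittest-style tensor comparisons (self.assertTrue(x.equal(y)), self.assertTrue(x.allclose(y, ...))) with torch.testing.assert_close and an assert_equal helper throughout test_functional_tensor.py. A minimal sketch of the pattern the diff applies, using throwaway tensors a and b rather than anything from the test suite:

import torch

a = torch.rand(3, 4)
b = a.clone()

# Old pattern: boolean checks wrapped in unittest assertions.
# self.assertTrue(a.equal(b))
# self.assertTrue(a.allclose(b, atol=1e-8))

# New pattern: torch.testing.assert_close raises an AssertionError that reports
# how many elements differ and by how much, which makes failures easier to debug.
torch.testing.assert_close(a, b)                      # dtype-based default tolerances
torch.testing.assert_close(a, b, rtol=0.0, atol=0.0)  # exact element-wise equality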

File tree: 1 file changed (+42 −42 lines)


test/test_functional_tensor.py

Lines changed: 42 additions & 42 deletions
@@ -15,6 +15,7 @@
 from torchvision.transforms import InterpolationMode
 
 from common_utils import TransformsTester, cpu_and_gpu, needs_cuda
+from _assert_utils import assert_equal
 
 from typing import Dict, List, Sequence, Tuple
 
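assert_equal is imported from a test-local _assert_utils module whose contents are not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply forwards to torch.testing.assert_close with zero tolerances so that exact equality is required while keeping assert_close's diagnostics:

# _assert_utils.py (hypothetical sketch; the real module is not shown in this commit)
import functools

import torch

# Exact comparison: zero tolerances, but assert_close's detailed error messages.
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)

Because functools.partial forwards extra keyword arguments, call sites can still pass msg= or check_stride=, as seen in the hunks below.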

@@ -39,13 +40,13 @@ def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwarg
         for i in range(len(batch_tensors)):
             img_tensor = batch_tensors[i, ...]
             transformed_img = fn(img_tensor, **fn_kwargs)
-            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+            assert_equal(transformed_img, transformed_batch[i, ...])
 
         if scripted_fn_atol >= 0:
             scripted_fn = torch.jit.script(fn)
             # scriptable function test
             s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
-            self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))
+            torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
 
     def test_assert_image_tensor(self):
         shape = (100,)
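One detail worth noting in the scripted-function check above: torch.testing.assert_close requires rtol and atol to be passed together (or both omitted, in which case dtype-based defaults apply), which is presumably why the diff pairs rtol=1e-5 with the pre-existing atol=scripted_fn_atol instead of passing atol alone. A tiny illustration with made-up tensors:

import torch

a = torch.rand(4)
b = a + 1e-9

torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-8)  # OK: both tolerances given
# torch.testing.assert_close(a, b, atol=1e-8)           # error: rtol must be given too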
@@ -79,7 +80,7 @@ def test_vflip(self):
 
         # scriptable function test
         vflipped_img_script = script_vflip(img_tensor)
-        self.assertTrue(vflipped_img.equal(vflipped_img_script))
+        assert_equal(vflipped_img, vflipped_img_script)
 
         batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
         self._test_fn_on_batch(batch_tensors, F.vflip)
@@ -94,7 +95,7 @@ def test_hflip(self):
 
         # scriptable function test
         hflipped_img_script = script_hflip(img_tensor)
-        self.assertTrue(hflipped_img.equal(hflipped_img_script))
+        assert_equal(hflipped_img, hflipped_img_script)
 
         batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
         self._test_fn_on_batch(batch_tensors, F.hflip)
@@ -140,11 +141,10 @@ def test_hsv2rgb(self):
             for h1, s1, v1 in zip(h, s, v):
                 rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
             colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
-            max_diff = (ft_img - colorsys_img).abs().max()
-            self.assertLess(max_diff, 1e-5)
+            torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)
 
             s_rgb_img = scripted_fn(hsv_img)
-            self.assertTrue(rgb_img.allclose(s_rgb_img))
+            torch.testing.assert_close(rgb_img, s_rgb_img)
 
         batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
         self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@@ -177,7 +177,7 @@ def test_rgb2hsv(self):
             self.assertLess(max_diff, 1e-5)
 
             s_hsv_img = scripted_fn(rgb_img)
-            self.assertTrue(hsv_img.allclose(s_hsv_img, atol=1e-7))
+            torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)
 
         batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
         self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
@@ -194,7 +194,7 @@ def test_rgb_to_grayscale(self):
             self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
 
             s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
-            self.assertTrue(s_gray_tensor.equal(gray_tensor))
+            assert_equal(s_gray_tensor, gray_tensor)
 
             batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
             self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@@ -240,12 +240,12 @@ def test_five_crop(self):
             for j in range(len(tuple_transformed_imgs)):
                 true_transformed_img = tuple_transformed_imgs[j]
                 transformed_img = tuple_transformed_batches[j][i, ...]
-                self.assertTrue(true_transformed_img.equal(transformed_img))
+                assert_equal(true_transformed_img, transformed_img)
 
         # scriptable function test
         s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
         for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
-            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+            assert_equal(transformed_batch, s_transformed_batch)
 
     def test_ten_crop(self):
         script_ten_crop = torch.jit.script(F.ten_crop)
@@ -272,12 +272,12 @@ def test_ten_crop(self):
             for j in range(len(tuple_transformed_imgs)):
                 true_transformed_img = tuple_transformed_imgs[j]
                 transformed_img = tuple_transformed_batches[j][i, ...]
-                self.assertTrue(true_transformed_img.equal(transformed_img))
+                assert_equal(true_transformed_img, transformed_img)
 
         # scriptable function test
         s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
         for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
-            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+            assert_equal(transformed_batch, s_transformed_batch)
 
     def test_pad(self):
         script_fn = torch.jit.script(F.pad)
@@ -320,7 +320,7 @@ def test_pad(self):
                     else:
                         script_pad = pad
                     pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
-                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
+                    assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, kwargs))
 
                     self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
 
@@ -348,9 +348,10 @@ def test_resize(self):
                         resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
                         resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
 
-                        self.assertEqual(
-                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
-                            msg="{}, {}".format(size, interpolation)
+                        assert_equal(
+                            resized_tensor.size()[1:],
+                            resized_pil_img.size[::-1],
+                            msg="{}, {}".format(size, interpolation),
                         )
 
                         if interpolation not in [NEAREST, ]:
@@ -374,7 +375,7 @@ def test_resize(self):
 
                         resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
                                                   max_size=max_size)
-                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+                        assert_equal(resized_tensor, resize_result, msg="{}, {}".format(size, interpolation))
 
                         self._test_fn_on_batch(
                             batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
@@ -384,7 +385,7 @@ def test_resize(self):
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.resize(tensor, size=32, interpolation=2)
             res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)
 
         for img in (tensor, pil_img):
             exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
@@ -400,15 +401,17 @@ def test_resized_crop(self):
 
         for mode in [NEAREST, BILINEAR, BICUBIC]:
             out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
-            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
+            assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
 
         # 2) resize by half and crop a TL corner
         tensor, _ = self._create_data(26, 36, device=self.device)
         out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
         expected_out_tensor = tensor[:, :20:2, :30:2]
-        self.assertTrue(
-            expected_out_tensor.equal(out_tensor),
-            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
+        assert_equal(
+            expected_out_tensor,
+            out_tensor,
+            check_stride=False,
+            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
         )
 
         batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
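The check_stride=False added above matters because expected_out_tensor is a strided slice view (tensor[:, :20:2, :30:2]) while the resized output is a freshly allocated, contiguous tensor; disabling the stride check keeps the comparison about values rather than memory layout. A small illustration with made-up tensors (depending on the PyTorch version, strides may or may not be checked by default, so passing the flag explicitly is the safe choice):

import torch

expected = torch.arange(36).reshape(6, 6)[::2, ::2]  # non-contiguous view, strides (12, 2)
actual = expected.contiguous()                        # same values, strides (3, 1)

# Values are identical; check_stride=False prevents the differing strides
# from failing the comparison.
torch.testing.assert_close(actual, expected, rtol=0, atol=0, check_stride=False)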
@@ -420,15 +423,11 @@ def _test_affine_identity_map(self, tensor, scripted_affine):
         # 1) identity map
         out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
 
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
+        assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
         out_tensor = scripted_affine(
             tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
         )
-        self.assertTrue(
-            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
-        )
+        assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
 
     def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
         # 2) Test rotation
@@ -452,9 +451,11 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
                     tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                 )
                 if true_tensor is not None:
-                    self.assertTrue(
-                        true_tensor.equal(out_tensor),
-                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                    assert_equal(
+                        true_tensor,
+                        out_tensor,
+                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5]),
+                        check_stride=False,
                     )
 
                 if out_tensor.dtype != torch.uint8:
@@ -593,18 +594,19 @@ def test_affine(self):
         with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
             res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
             res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)
 
         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
             res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)
 
         with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
             res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
             res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
-            self.assertEqual(res1, res2)
+            # we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
+            assert_equal(np.asarray(res1), np.asarray(res2))
 
     def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
         img_size = pil_img.size
@@ -682,13 +684,13 @@ def test_rotate(self):
         with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
             res1 = F.rotate(tensor, 45, resample=2)
             res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)
 
         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.rotate(tensor, 45, interpolation=2)
             res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
-            self.assertTrue(res1.equal(res2))
+            assert_equal(res1, res2)
 
     def test_gaussian_blur(self):
         small_image_tensor = torch.from_numpy(
@@ -747,10 +749,8 @@ def test_gaussian_blur(self):
 
             for fn in [F.gaussian_blur, scripted_transform]:
                 out = fn(tensor, kernel_size=ksize, sigma=sigma)
-                self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
-                self.assertLessEqual(
-                    torch.max(true_out.float() - out.float()),
-                    1.0,
+                torch.testing.assert_close(
+                    out, true_out, rtol=0.0, atol=1.0, check_stride=False,
                     msg="{}, {}".format(ksize, sigma)
                 )
 
@@ -771,7 +771,7 @@ def test_scale_channel(self):
         img_chan = torch.randint(0, 256, size=size).to('cpu')
         scaled_cpu = F_t._scale_channel(img_chan)
         scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
-        self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
+        assert_equal(scaled_cpu, scaled_cuda.to('cpu'))
 
 
 def _get_data_dims_and_points_for_perspective():

0 commit comments
