15
15
from torchvision .transforms import InterpolationMode
16
16
17
17
from common_utils import TransformsTester , cpu_and_gpu , needs_cuda
18
+ from _assert_utils import assert_equal
18
19
19
20
from typing import Dict , List , Sequence , Tuple
20
21
@@ -39,13 +40,13 @@ def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwarg
39
40
for i in range (len (batch_tensors )):
40
41
img_tensor = batch_tensors [i , ...]
41
42
transformed_img = fn (img_tensor , ** fn_kwargs )
42
- self . assertTrue (transformed_img . equal ( transformed_batch [i , ...]) )
43
+ assert_equal (transformed_img , transformed_batch [i , ...])
43
44
44
45
if scripted_fn_atol >= 0 :
45
46
scripted_fn = torch .jit .script (fn )
46
47
# scriptable function test
47
48
s_transformed_batch = scripted_fn (batch_tensors , ** fn_kwargs )
48
- self . assertTrue ( transformed_batch . allclose ( s_transformed_batch , atol = scripted_fn_atol ) )
49
+ torch . testing . assert_close ( transformed_batch , s_transformed_batch , rtol = 1e-5 , atol = scripted_fn_atol )
49
50
50
51
def test_assert_image_tensor (self ):
51
52
shape = (100 ,)
@@ -79,7 +80,7 @@ def test_vflip(self):
79
80
80
81
# scriptable function test
81
82
vflipped_img_script = script_vflip (img_tensor )
82
- self . assertTrue (vflipped_img . equal ( vflipped_img_script ) )
83
+ assert_equal (vflipped_img , vflipped_img_script )
83
84
84
85
batch_tensors = self ._create_data_batch (16 , 18 , num_samples = 4 , device = self .device )
85
86
self ._test_fn_on_batch (batch_tensors , F .vflip )
@@ -94,7 +95,7 @@ def test_hflip(self):
94
95
95
96
# scriptable function test
96
97
hflipped_img_script = script_hflip (img_tensor )
97
- self . assertTrue (hflipped_img . equal ( hflipped_img_script ) )
98
+ assert_equal (hflipped_img , hflipped_img_script )
98
99
99
100
batch_tensors = self ._create_data_batch (16 , 18 , num_samples = 4 , device = self .device )
100
101
self ._test_fn_on_batch (batch_tensors , F .hflip )
@@ -140,11 +141,10 @@ def test_hsv2rgb(self):
140
141
for h1 , s1 , v1 in zip (h , s , v ):
141
142
rgb .append (colorsys .hsv_to_rgb (h1 , s1 , v1 ))
142
143
colorsys_img = torch .tensor (rgb , dtype = torch .float32 , device = self .device )
143
- max_diff = (ft_img - colorsys_img ).abs ().max ()
144
- self .assertLess (max_diff , 1e-5 )
144
+ torch .testing .assert_close (ft_img , colorsys_img , rtol = 0.0 , atol = 1e-5 )
145
145
146
146
s_rgb_img = scripted_fn (hsv_img )
147
- self . assertTrue ( rgb_img . allclose ( s_rgb_img ) )
147
+ torch . testing . assert_close ( rgb_img , s_rgb_img )
148
148
149
149
batch_tensors = self ._create_data_batch (120 , 100 , num_samples = 4 , device = self .device ).float ()
150
150
self ._test_fn_on_batch (batch_tensors , F_t ._hsv2rgb )
@@ -177,7 +177,7 @@ def test_rgb2hsv(self):
177
177
self .assertLess (max_diff , 1e-5 )
178
178
179
179
s_hsv_img = scripted_fn (rgb_img )
180
- self . assertTrue ( hsv_img . allclose ( s_hsv_img , atol = 1e-7 ) )
180
+ torch . testing . assert_close ( hsv_img , s_hsv_img , rtol = 1e-5 , atol = 1e-7 )
181
181
182
182
batch_tensors = self ._create_data_batch (120 , 100 , num_samples = 4 , device = self .device ).float ()
183
183
self ._test_fn_on_batch (batch_tensors , F_t ._rgb2hsv )
@@ -194,7 +194,7 @@ def test_rgb_to_grayscale(self):
194
194
self .approxEqualTensorToPIL (gray_tensor .float (), gray_pil_image , tol = 1.0 + 1e-10 , agg_method = "max" )
195
195
196
196
s_gray_tensor = script_rgb_to_grayscale (img_tensor , num_output_channels = num_output_channels )
197
- self . assertTrue (s_gray_tensor . equal ( gray_tensor ) )
197
+ assert_equal (s_gray_tensor , gray_tensor )
198
198
199
199
batch_tensors = self ._create_data_batch (16 , 18 , num_samples = 4 , device = self .device )
200
200
self ._test_fn_on_batch (batch_tensors , F .rgb_to_grayscale , num_output_channels = num_output_channels )
@@ -240,12 +240,12 @@ def test_five_crop(self):
240
240
for j in range (len (tuple_transformed_imgs )):
241
241
true_transformed_img = tuple_transformed_imgs [j ]
242
242
transformed_img = tuple_transformed_batches [j ][i , ...]
243
- self . assertTrue (true_transformed_img . equal ( transformed_img ) )
243
+ assert_equal (true_transformed_img , transformed_img )
244
244
245
245
# scriptable function test
246
246
s_tuple_transformed_batches = script_five_crop (batch_tensors , [10 , 11 ])
247
247
for transformed_batch , s_transformed_batch in zip (tuple_transformed_batches , s_tuple_transformed_batches ):
248
- self . assertTrue (transformed_batch . equal ( s_transformed_batch ) )
248
+ assert_equal (transformed_batch , s_transformed_batch )
249
249
250
250
def test_ten_crop (self ):
251
251
script_ten_crop = torch .jit .script (F .ten_crop )
@@ -272,12 +272,12 @@ def test_ten_crop(self):
272
272
for j in range (len (tuple_transformed_imgs )):
273
273
true_transformed_img = tuple_transformed_imgs [j ]
274
274
transformed_img = tuple_transformed_batches [j ][i , ...]
275
- self . assertTrue (true_transformed_img . equal ( transformed_img ) )
275
+ assert_equal (true_transformed_img , transformed_img )
276
276
277
277
# scriptable function test
278
278
s_tuple_transformed_batches = script_ten_crop (batch_tensors , [10 , 11 ])
279
279
for transformed_batch , s_transformed_batch in zip (tuple_transformed_batches , s_tuple_transformed_batches ):
280
- self . assertTrue (transformed_batch . equal ( s_transformed_batch ) )
280
+ assert_equal (transformed_batch , s_transformed_batch )
281
281
282
282
def test_pad (self ):
283
283
script_fn = torch .jit .script (F .pad )
@@ -320,7 +320,7 @@ def test_pad(self):
320
320
else :
321
321
script_pad = pad
322
322
pad_tensor_script = script_fn (tensor , script_pad , ** kwargs )
323
- self . assertTrue (pad_tensor . equal ( pad_tensor_script ) , msg = "{}, {}" .format (pad , kwargs ))
323
+ assert_equal (pad_tensor , pad_tensor_script , msg = "{}, {}" .format (pad , kwargs ))
324
324
325
325
self ._test_fn_on_batch (batch_tensors , F .pad , padding = script_pad , ** kwargs )
326
326
@@ -348,9 +348,10 @@ def test_resize(self):
348
348
resized_tensor = F .resize (tensor , size = size , interpolation = interpolation , max_size = max_size )
349
349
resized_pil_img = F .resize (pil_img , size = size , interpolation = interpolation , max_size = max_size )
350
350
351
- self .assertEqual (
352
- resized_tensor .size ()[1 :], resized_pil_img .size [::- 1 ],
353
- msg = "{}, {}" .format (size , interpolation )
351
+ assert_equal (
352
+ resized_tensor .size ()[1 :],
353
+ resized_pil_img .size [::- 1 ],
354
+ msg = "{}, {}" .format (size , interpolation ),
354
355
)
355
356
356
357
if interpolation not in [NEAREST , ]:
@@ -374,7 +375,7 @@ def test_resize(self):
374
375
375
376
resize_result = script_fn (tensor , size = script_size , interpolation = interpolation ,
376
377
max_size = max_size )
377
- self . assertTrue (resized_tensor . equal ( resize_result ) , msg = "{}, {}" .format (size , interpolation ))
378
+ assert_equal (resized_tensor , resize_result , msg = "{}, {}" .format (size , interpolation ))
378
379
379
380
self ._test_fn_on_batch (
380
381
batch_tensors , F .resize , size = script_size , interpolation = interpolation , max_size = max_size
@@ -384,7 +385,7 @@ def test_resize(self):
384
385
with self .assertWarnsRegex (UserWarning , r"Argument interpolation should be of type InterpolationMode" ):
385
386
res1 = F .resize (tensor , size = 32 , interpolation = 2 )
386
387
res2 = F .resize (tensor , size = 32 , interpolation = BILINEAR )
387
- self . assertTrue (res1 . equal ( res2 ) )
388
+ assert_equal (res1 , res2 )
388
389
389
390
for img in (tensor , pil_img ):
390
391
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
@@ -400,15 +401,17 @@ def test_resized_crop(self):
400
401
401
402
for mode in [NEAREST , BILINEAR , BICUBIC ]:
402
403
out_tensor = F .resized_crop (tensor , top = 0 , left = 0 , height = 26 , width = 36 , size = [26 , 36 ], interpolation = mode )
403
- self . assertTrue (tensor . equal ( out_tensor ) , msg = "{} vs {}" .format (out_tensor [0 , :5 , :5 ], tensor [0 , :5 , :5 ]))
404
+ assert_equal (tensor , out_tensor , msg = "{} vs {}" .format (out_tensor [0 , :5 , :5 ], tensor [0 , :5 , :5 ]))
404
405
405
406
# 2) resize by half and crop a TL corner
406
407
tensor , _ = self ._create_data (26 , 36 , device = self .device )
407
408
out_tensor = F .resized_crop (tensor , top = 0 , left = 0 , height = 20 , width = 30 , size = [10 , 15 ], interpolation = NEAREST )
408
409
expected_out_tensor = tensor [:, :20 :2 , :30 :2 ]
409
- self .assertTrue (
410
- expected_out_tensor .equal (out_tensor ),
411
- msg = "{} vs {}" .format (expected_out_tensor [0 , :10 , :10 ], out_tensor [0 , :10 , :10 ])
410
+ assert_equal (
411
+ expected_out_tensor ,
412
+ out_tensor ,
413
+ check_stride = False ,
414
+ msg = "{} vs {}" .format (expected_out_tensor [0 , :10 , :10 ], out_tensor [0 , :10 , :10 ]),
412
415
)
413
416
414
417
batch_tensors = self ._create_data_batch (26 , 36 , num_samples = 4 , device = self .device )
@@ -420,15 +423,11 @@ def _test_affine_identity_map(self, tensor, scripted_affine):
420
423
# 1) identity map
421
424
out_tensor = F .affine (tensor , angle = 0 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], interpolation = NEAREST )
422
425
423
- self .assertTrue (
424
- tensor .equal (out_tensor ), msg = "{} vs {}" .format (out_tensor [0 , :5 , :5 ], tensor [0 , :5 , :5 ])
425
- )
426
+ assert_equal (tensor , out_tensor , msg = "{} vs {}" .format (out_tensor [0 , :5 , :5 ], tensor [0 , :5 , :5 ]))
426
427
out_tensor = scripted_affine (
427
428
tensor , angle = 0 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], interpolation = NEAREST
428
429
)
429
- self .assertTrue (
430
- tensor .equal (out_tensor ), msg = "{} vs {}" .format (out_tensor [0 , :5 , :5 ], tensor [0 , :5 , :5 ])
431
- )
430
+ assert_equal (tensor , out_tensor , msg = "{} vs {}" .format (out_tensor [0 , :5 , :5 ], tensor [0 , :5 , :5 ]))
432
431
433
432
def _test_affine_square_rotations (self , tensor , pil_img , scripted_affine ):
434
433
# 2) Test rotation
@@ -452,9 +451,11 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
452
451
tensor , angle = a , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], interpolation = NEAREST
453
452
)
454
453
if true_tensor is not None :
455
- self .assertTrue (
456
- true_tensor .equal (out_tensor ),
457
- msg = "{}\n {} vs \n {}" .format (a , out_tensor [0 , :5 , :5 ], true_tensor [0 , :5 , :5 ])
454
+ assert_equal (
455
+ true_tensor ,
456
+ out_tensor ,
457
+ msg = "{}\n {} vs \n {}" .format (a , out_tensor [0 , :5 , :5 ], true_tensor [0 , :5 , :5 ]),
458
+ check_stride = False ,
458
459
)
459
460
460
461
if out_tensor .dtype != torch .uint8 :
@@ -593,18 +594,19 @@ def test_affine(self):
593
594
with self .assertWarnsRegex (UserWarning , r"Argument resample is deprecated and will be removed" ):
594
595
res1 = F .affine (tensor , 45 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], resample = 2 )
595
596
res2 = F .affine (tensor , 45 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], interpolation = BILINEAR )
596
- self . assertTrue (res1 . equal ( res2 ) )
597
+ assert_equal (res1 , res2 )
597
598
598
599
# assert changed type warning
599
600
with self .assertWarnsRegex (UserWarning , r"Argument interpolation should be of type InterpolationMode" ):
600
601
res1 = F .affine (tensor , 45 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], interpolation = 2 )
601
602
res2 = F .affine (tensor , 45 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], interpolation = BILINEAR )
602
- self . assertTrue (res1 . equal ( res2 ) )
603
+ assert_equal (res1 , res2 )
603
604
604
605
with self .assertWarnsRegex (UserWarning , r"Argument fillcolor is deprecated and will be removed" ):
605
606
res1 = F .affine (pil_img , 45 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], fillcolor = 10 )
606
607
res2 = F .affine (pil_img , 45 , translate = [0 , 0 ], scale = 1.0 , shear = [0.0 , 0.0 ], fill = 10 )
607
- self .assertEqual (res1 , res2 )
608
+ # we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
609
+ assert_equal (np .asarray (res1 ), np .asarray (res2 ))
608
610
609
611
def _test_rotate_all_options (self , tensor , pil_img , scripted_rotate , centers ):
610
612
img_size = pil_img .size
@@ -682,13 +684,13 @@ def test_rotate(self):
682
684
with self .assertWarnsRegex (UserWarning , r"Argument resample is deprecated and will be removed" ):
683
685
res1 = F .rotate (tensor , 45 , resample = 2 )
684
686
res2 = F .rotate (tensor , 45 , interpolation = BILINEAR )
685
- self . assertTrue (res1 . equal ( res2 ) )
687
+ assert_equal (res1 , res2 )
686
688
687
689
# assert changed type warning
688
690
with self .assertWarnsRegex (UserWarning , r"Argument interpolation should be of type InterpolationMode" ):
689
691
res1 = F .rotate (tensor , 45 , interpolation = 2 )
690
692
res2 = F .rotate (tensor , 45 , interpolation = BILINEAR )
691
- self . assertTrue (res1 . equal ( res2 ) )
693
+ assert_equal (res1 , res2 )
692
694
693
695
def test_gaussian_blur (self ):
694
696
small_image_tensor = torch .from_numpy (
@@ -747,10 +749,8 @@ def test_gaussian_blur(self):
747
749
748
750
for fn in [F .gaussian_blur , scripted_transform ]:
749
751
out = fn (tensor , kernel_size = ksize , sigma = sigma )
750
- self .assertEqual (true_out .shape , out .shape , msg = "{}, {}" .format (ksize , sigma ))
751
- self .assertLessEqual (
752
- torch .max (true_out .float () - out .float ()),
753
- 1.0 ,
752
+ torch .testing .assert_close (
753
+ out , true_out , rtol = 0.0 , atol = 1.0 , check_stride = False ,
754
754
msg = "{}, {}" .format (ksize , sigma )
755
755
)
756
756
@@ -771,7 +771,7 @@ def test_scale_channel(self):
771
771
img_chan = torch .randint (0 , 256 , size = size ).to ('cpu' )
772
772
scaled_cpu = F_t ._scale_channel (img_chan )
773
773
scaled_cuda = F_t ._scale_channel (img_chan .to ('cuda' ))
774
- self . assertTrue (scaled_cpu . equal ( scaled_cuda .to ('cpu' ) ))
774
+ assert_equal (scaled_cpu , scaled_cuda .to ('cpu' ))
775
775
776
776
777
777
def _get_data_dims_and_points_for_perspective ():
0 commit comments