Skip to content

Commit e61b68e

Browse files
authored
F.normalize unsqueeze mean & std only for 1-d arrays (#2002)
* F.normalize unsqueeze mean & std if necessary * added tests to F.normalize for 3d mean & std tensors
1 parent ae228fe commit e61b68e

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

test/test_transforms.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torchvision.transforms as transforms
66
import torchvision.transforms.functional as F
77
from torch._utils_internal import get_file_path_2
8+
from numpy.testing import assert_array_almost_equal
89
import unittest
910
import math
1011
import random
@@ -843,6 +844,25 @@ def test_normalize_different_dtype(self):
843844
# checks that it doesn't crash
844845
transforms.functional.normalize(img, mean, std)
845846

847+
def test_normalize_3d_tensor(self):
848+
torch.manual_seed(28)
849+
n_channels = 3
850+
img_size = 10
851+
mean = torch.rand(n_channels)
852+
std = torch.rand(n_channels)
853+
img = torch.rand(n_channels, img_size, img_size)
854+
target = F.normalize(img, mean, std).numpy()
855+
856+
mean_unsqueezed = mean.view(-1, 1, 1)
857+
std_unsqueezed = std.view(-1, 1, 1)
858+
result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
859+
result2 = F.normalize(img,
860+
mean_unsqueezed.repeat(1, img_size, img_size),
861+
std_unsqueezed.repeat(1, img_size, img_size))
862+
assert_array_almost_equal(target, result1.numpy())
863+
assert_array_almost_equal(target, result2.numpy())
864+
865+
846866
def test_adjust_brightness(self):
847867
x_shape = [2, 2, 3]
848868
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]

torchvision/transforms/functional.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,11 @@ def normalize(tensor, mean, std, inplace=False):
211211
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
212212
if (std == 0).any():
213213
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
214-
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
214+
if mean.ndim == 1:
215+
mean = mean[:, None, None]
216+
if std.ndim == 1:
217+
std = std[:, None, None]
218+
tensor.sub_(mean).div_(std)
215219
return tensor
216220

217221

0 commit comments

Comments
 (0)