|
5 | 5 | import torchvision.transforms as transforms
|
6 | 6 | import torchvision.transforms.functional as F
|
7 | 7 | from torch._utils_internal import get_file_path_2
|
| 8 | +from numpy.testing import assert_array_almost_equal |
8 | 9 | import unittest
|
9 | 10 | import math
|
10 | 11 | import random
|
@@ -843,6 +844,25 @@ def test_normalize_different_dtype(self):
|
843 | 844 | # checks that it doesn't crash
|
844 | 845 | transforms.functional.normalize(img, mean, std)
|
845 | 846 |
|
| 847 | + def test_normalize_3d_tensor(self): |
| 848 | + torch.manual_seed(28) |
| 849 | + n_channels = 3 |
| 850 | + img_size = 10 |
| 851 | + mean = torch.rand(n_channels) |
| 852 | + std = torch.rand(n_channels) |
| 853 | + img = torch.rand(n_channels, img_size, img_size) |
| 854 | + target = F.normalize(img, mean, std).numpy() |
| 855 | + |
| 856 | + mean_unsqueezed = mean.view(-1, 1, 1) |
| 857 | + std_unsqueezed = std.view(-1, 1, 1) |
| 858 | + result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed) |
| 859 | + result2 = F.normalize(img, |
| 860 | + mean_unsqueezed.repeat(1, img_size, img_size), |
| 861 | + std_unsqueezed.repeat(1, img_size, img_size)) |
| 862 | + assert_array_almost_equal(target, result1.numpy()) |
| 863 | + assert_array_almost_equal(target, result2.numpy()) |
| 864 | + |
| 865 | + |
846 | 866 | def test_adjust_brightness(self):
|
847 | 867 | x_shape = [2, 2, 3]
|
848 | 868 | x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
|
0 commit comments