Skip to content

Commit 6cbb22b

Browse files
bodokaisersoumith
authored andcommitted
added support for signed ints, removed support for unsigned
1 parent c1a835f commit 6cbb22b

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

test/test_transforms.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,35 @@ def test_ndarray_to_pil_image(self):
169169
l, = img.split()
170170
assert np.allclose(l, img_data[:, :, 0])
171171

172-
def test_ndarray16_to_pil_image(self):
172+
def test_ndarray_bad_types_to_pil_image(self):
173173
trans = transforms.ToPILImage()
174-
img_data = np.random.randint(0, 65535, [4, 4, 1], np.uint16)
174+
with self.assertRaises(AssertionError):
175+
trans(np.ones([4, 4, 1], np.int64))
176+
trans(np.ones([4, 4, 1], np.uint16))
177+
trans(np.ones([4, 4, 1], np.uint32))
178+
trans(np.ones([4, 4, 1], np.float64))
179+
180+
def test_ndarray_gray_float32_to_pil_image(self):
181+
trans = transforms.ToPILImage()
182+
img_data = torch.FloatTensor(4, 4, 1).random_().numpy()
183+
img = trans(img_data)
184+
assert img.mode == 'F'
185+
assert np.allclose(img, img_data[:, :, 0])
186+
187+
def test_ndarray_gray_int16_to_pil_image(self):
188+
trans = transforms.ToPILImage()
189+
img_data = torch.ShortTensor(4, 4, 1).random_().numpy()
175190
img = trans(img_data)
176191
assert img.mode == 'I;16'
177192
assert np.allclose(img, img_data[:, :, 0])
178193

194+
def test_ndarray_gray_int32_to_pil_image(self):
195+
trans = transforms.ToPILImage()
196+
img_data = torch.IntTensor(4, 4, 1).random_().numpy()
197+
img = trans(img_data)
198+
assert img.mode == 'I'
199+
assert np.allclose(img, img_data[:, :, 0])
200+
179201

180202
if __name__ == '__main__':
181203
unittest.main()

torchvision/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ def __call__(self, pic):
7373

7474
if npimg.dtype == np.uint8:
7575
mode = 'L'
76-
if npimg.dtype == np.uint16:
76+
if npimg.dtype == np.int16:
7777
mode = 'I;16'
78+
if npimg.dtype == np.int32:
79+
mode = 'I'
7880
elif npimg.dtype == np.float32:
7981
mode = 'F'
8082
else:

0 commit comments

Comments
 (0)