Skip to content

Commit 991bad2

Browse files
bodokaisersoumith
authored andcommitted
updated ToTensor to support more types
1 parent 6cbb22b commit 991bad2

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

test/test_transforms.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,30 @@ def test_tensor_to_pil_image(self):
151151
expected_output = img_data.mul(255).int().float().div(255)
152152
assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())
153153

154+
def test_tensor_gray_to_pil_image(self):
155+
trans = transforms.ToPILImage()
156+
to_tensor = transforms.ToTensor()
157+
158+
img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
159+
img_data_short = torch.ShortTensor(1, 4, 4).random_()
160+
img_data_int = torch.IntTensor(1, 4, 4).random_()
161+
img_data_float = torch.FloatTensor(1, 4, 4).uniform_()
162+
163+
img_byte = trans(img_data_byte)
164+
img_short = trans(img_data_short)
165+
img_int = trans(img_data_int)
166+
img_float = trans(img_data_float)
167+
assert img_byte.mode == 'L'
168+
assert img_short.mode == 'I;16'
169+
assert img_int.mode == 'I'
170+
#assert img_float.mode == 'F'
171+
172+
assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy())
173+
assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy())
174+
# would cause breaking changes as ToTensor converts to range [0, 1]
175+
#assert np.allclose(img_data_byte.numpy(), to_tensor(img_byte).numpy())
176+
#assert np.allclose(img_data_float.numpy(), to_tensor(img_float).numpy())
177+
154178
def test_ndarray_to_pil_image(self):
155179
trans = transforms.ToPILImage()
156180
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()

torchvision/transforms.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,30 @@ def __call__(self, pic):
3939
if isinstance(pic, np.ndarray):
4040
# handle numpy array
4141
img = torch.from_numpy(pic.transpose((2, 0, 1)))
42+
# backard compability
43+
return img.float().div(255)
44+
# handle PIL Image
45+
if pic.mode == 'I':
46+
img = torch.from_numpy(np.array(pic, np.int32))
47+
elif pic.mode == 'I;16':
48+
img = torch.from_numpy(np.array(pic, np.int16))
4249
else:
43-
# handle PIL Image
4450
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
45-
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
46-
if pic.mode == 'YCbCr':
47-
nchannel = 3
48-
else:
49-
nchannel = len(pic.mode)
50-
img = img.view(pic.size[1], pic.size[0], nchannel)
51-
# put it from HWC to CHW format
52-
# yikes, this transpose takes 80% of the loading time/CPU
53-
img = img.transpose(0, 1).transpose(0, 2).contiguous()
54-
return img.float().div(255)
51+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
52+
if pic.mode == 'YCbCr':
53+
nchannel = 3
54+
elif pic.mode == 'I;16':
55+
nchannel = 1
56+
else:
57+
nchannel = len(pic.mode)
58+
img = img.view(pic.size[1], pic.size[0], nchannel)
59+
# put it from HWC to CHW format
60+
# yikes, this transpose takes 80% of the loading time/CPU
61+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
62+
if isinstance(img, torch.ByteTensor):
63+
return img.float().div(255)
64+
else:
65+
return img
5566

5667

5768
class ToPILImage(object):
@@ -67,7 +78,6 @@ def __call__(self, pic):
6778
if torch.is_tensor(pic):
6879
npimg = np.transpose(pic.numpy(), (1, 2, 0))
6980
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
70-
7181
if npimg.shape[2] == 1:
7282
npimg = npimg[:, :, 0]
7383

@@ -83,7 +93,6 @@ def __call__(self, pic):
8393
if npimg.dtype == np.uint8:
8494
mode = 'RGB'
8595
assert mode is not None, '{} is not supported'.format(npimg.dtype)
86-
8796
return Image.fromarray(npimg, mode=mode)
8897

8998

0 commit comments

Comments
 (0)