Skip to content

Commit f6ab107

Browse files
arturmlfmassa
authored andcommitted
Add support in transforms.ToTensor for PIL Images mod '1' (#471)
* Add case in test_to_tensor for PIL Images mode '1' * Add support in ToTensor for PIL Images mode '1' * Fix pep8 issues
1 parent 7bda0e8 commit f6ab107

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

test/test_transforms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,12 @@ def test_to_tensor(self):
404404
expected_output = ndarray.transpose((2, 0, 1))
405405
assert np.allclose(output.numpy(), expected_output)
406406

407+
# separate test for mode '1' PIL images
408+
input_data = torch.ByteTensor(1, height, width).bernoulli_()
409+
img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
410+
output = trans(img)
411+
assert np.allclose(input_data.numpy(), output.numpy())
412+
407413
@unittest.skipIf(accimage is None, 'accimage not available')
408414
def test_accimage_to_tensor(self):
409415
trans = transforms.ToTensor()

torchvision/transforms/functional.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def to_tensor(pic):
6464
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
6565
elif pic.mode == 'F':
6666
img = torch.from_numpy(np.array(pic, np.float32, copy=False))
67+
elif pic.mode == '1':
68+
img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
6769
else:
6870
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
69-
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
71+
# PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
7072
if pic.mode == 'YCbCr':
7173
nchannel = 3
7274
elif pic.mode == 'I;16':

0 commit comments

Comments
 (0)