Skip to content

Commit 067b9dc

Browse files
authored
Functional to_tensor returns float tensor of default dtype (#3398)
Fixes #3393
1 parent f04e9cb commit 067b9dc

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

test/test_transforms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,20 @@ def test_to_tensor(self):
620620
output = trans(img)
621621
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
622622

623+
def test_to_tensor_with_other_default_dtypes(self):
624+
current_def_dtype = torch.get_default_dtype()
625+
626+
t = transforms.ToTensor()
627+
np_arr = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
628+
img = Image.fromarray(np_arr)
629+
630+
for dtype in [torch.float16, torch.float, torch.double]:
631+
torch.set_default_dtype(dtype)
632+
res = t(img)
633+
self.assertTrue(res.dtype == dtype, msg=f"{res.dtype} vs {dtype}")
634+
635+
torch.set_default_dtype(current_def_dtype)
636+
623637
def test_max_value(self):
624638
for dtype in int_dtypes():
625639
self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max)

torchvision/transforms/functional.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def to_tensor(pic):
104104
if _is_numpy(pic) and not _is_numpy_image(pic):
105105
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
106106

107+
default_float_dtype = torch.get_default_dtype()
108+
107109
if isinstance(pic, np.ndarray):
108110
# handle numpy array
109111
if pic.ndim == 2:
@@ -112,12 +114,12 @@ def to_tensor(pic):
112114
img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
113115
# backward compatibility
114116
if isinstance(img, torch.ByteTensor):
115-
return img.float().div(255)
117+
return img.to(dtype=default_float_dtype).div(255)
116118
else:
117119
return img
118120

119121
if accimage is not None and isinstance(pic, accimage.Image):
120-
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
122+
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=default_float_dtype)
121123
pic.copyto(nppic)
122124
return torch.from_numpy(nppic)
123125

@@ -137,7 +139,7 @@ def to_tensor(pic):
137139
# put it from HWC to CHW format
138140
img = img.permute((2, 0, 1)).contiguous()
139141
if isinstance(img, torch.ByteTensor):
140-
return img.float().div(255)
142+
return img.to(dtype=default_float_dtype).div(255)
141143
else:
142144
return img
143145

0 commit comments

Comments
 (0)