@@ -104,6 +104,8 @@ def to_tensor(pic):
104
104
if _is_numpy (pic ) and not _is_numpy_image (pic ):
105
105
raise ValueError ('pic should be 2/3 dimensional. Got {} dimensions.' .format (pic .ndim ))
106
106
107
+ default_float_dtype = torch .get_default_dtype ()
108
+
107
109
if isinstance (pic , np .ndarray ):
108
110
# handle numpy array
109
111
if pic .ndim == 2 :
@@ -112,12 +114,12 @@ def to_tensor(pic):
112
114
img = torch .from_numpy (pic .transpose ((2 , 0 , 1 ))).contiguous ()
113
115
# backward compatibility
114
116
if isinstance (img , torch .ByteTensor ):
115
- return img .float ( ).div (255 )
117
+ return img .to ( dtype = default_float_dtype ).div (255 )
116
118
else :
117
119
return img
118
120
119
121
if accimage is not None and isinstance (pic , accimage .Image ):
120
- nppic = np .zeros ([pic .channels , pic .height , pic .width ], dtype = np . float32 )
122
+ nppic = np .zeros ([pic .channels , pic .height , pic .width ], dtype = default_float_dtype )
121
123
pic .copyto (nppic )
122
124
return torch .from_numpy (nppic )
123
125
@@ -137,7 +139,7 @@ def to_tensor(pic):
137
139
# put it from HWC to CHW format
138
140
img = img .permute ((2 , 0 , 1 )).contiguous ()
139
141
if isinstance (img , torch .ByteTensor ):
140
- return img .float ( ).div (255 )
142
+ return img .to ( dtype = default_float_dtype ).div (255 )
141
143
else :
142
144
return img
143
145
0 commit comments