Skip to content

Commit 7d5c07d

Browse files
committed
Better types
1 parent 50216f7 commit 7d5c07d

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

torchvision/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_l
178178

179179
def _Image_fromarray(
180180
obj: np.ndarray,
181-
mode: Optional[str],
181+
mode: str,
182182
) -> Image.Image:
183183
"""
184184
A wrapper around PIL.Image.fromarray to mitigate the deprecation of the
@@ -217,16 +217,18 @@ def _Image_fromarray(
217217
size = 1 if ndim == 1 else shape[1], shape[0]
218218

219219
strides = arr.get("strides", None)
220+
contiguous_obj: Union[np.ndarray, bytes] = obj
220221
if strides is not None:
222+
# We require that the data is contiguous; if it is not, we need to
223+
# convert it into a contiguous format.
221224
if hasattr(obj, "tobytes"):
222-
obj = obj.tobytes()
225+
contiguous_obj = obj.tobytes()
223226
elif hasattr(obj, "tostring"):
224-
obj = obj.tostring()
227+
contiguous_obj = obj.tostring()
225228
else:
226-
msg = "'strides' requires either tobytes() or tostring()"
227-
raise ValueError(msg)
229+
raise ValueError("Unable to convert obj into contiguous format")
228230

229-
return Image.frombuffer(mode, size, obj, "raw", mode, 0, 1)
231+
return Image.frombuffer(mode, size, contiguous_obj, "raw", mode, 0, 1)
230232
else:
231233
return Image.fromarray(obj, mode)
232234

0 commit comments

Comments
 (0)