Skip to content

Commit 50216f7

Browse files
committed
Solution lifted from implementation of fromarray()
1 parent 8345753 commit 50216f7

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

torchvision/utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,47 @@ def _Image_fromarray(
186186
https://pillow.readthedocs.io/en/stable/releasenotes/11.3.0.html#image-fromarray-mode-parameter
187187
"""
188188
if PILLOW_VERSION >= (11, 3):
189-
return Image.fromarray(obj)
189+
# We actually rely on the old behavior of Image.fromarray():
190+
#
191+
# new behavior: PIL will infer the image mode from the data passed in.
192+
# That is, the type and shape determines the mode.
193+
#
194+
# old behiavor: The mode will change how PIL reads the image,
195+
# regardless of the data. That is, it will make the data
196+
# work with the mode.
197+
#
198+
# Our uses of Image.fromarray() are effectively a "turn into PIL image
199+
# AND convert the kind" operation. In particular, in
200+
# functional.to_pil_image() and transforms.ToPILImage.
201+
#
202+
# However, Image.frombuffer() still performs this conversion. The code
203+
# below is lifted from the new implementation of Image.fromarray(). We
204+
# omit the code that infers the mode, and use the code that figures out
205+
# from the data passed in (obj) what the correct parameters are to
206+
# Image.frombuffer().
207+
#
208+
# Note that the alternate solution below does not work:
209+
#
210+
# img = Image.fromarray(obj)
211+
# img = img.convert(mode)
212+
#
213+
# The resulting image has very different actual pixel values than before.
214+
arr = obj.__array_interface__
215+
shape = arr["shape"]
216+
ndim = len(shape)
217+
size = 1 if ndim == 1 else shape[1], shape[0]
218+
219+
strides = arr.get("strides", None)
220+
if strides is not None:
221+
if hasattr(obj, "tobytes"):
222+
obj = obj.tobytes()
223+
elif hasattr(obj, "tostring"):
224+
obj = obj.tostring()
225+
else:
226+
msg = "'strides' requires either tobytes() or tostring()"
227+
raise ValueError(msg)
228+
229+
return Image.frombuffer(mode, size, obj, "raw", mode, 0, 1)
190230
else:
191231
return Image.fromarray(obj, mode)
192232

0 commit comments

Comments
 (0)