-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Mitigate PIL Image.fromarray() mode deprecation #9150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
b079a96
425a923
1dfd59d
044eacd
005a424
882eb31
8345753
50216f7
7d5c07d
1594569
4ace920
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,12 @@ | |
|
||
import numpy as np | ||
import torch | ||
from PIL import Image, ImageColor, ImageDraw, ImageFont | ||
from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont | ||
|
||
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION_STRING.split(".")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I stole this from our testing code. I'm doing this instead of adding a dependence on the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed not to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the case that it fails, I'm going to propose we fall back to the old behavior since it's like more people will have the old PIL versions. |
||
|
||
__all__ = [ | ||
"_Image_fromarray", | ||
"make_grid", | ||
"save_image", | ||
"draw_bounding_boxes", | ||
|
@@ -174,6 +176,63 @@ def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_l | |
current_dash = not current_dash | ||
|
||
|
||
def _Image_fromarray( | ||
obj: np.ndarray, | ||
mode: str, | ||
) -> Image.Image: | ||
""" | ||
A wrapper around PIL.Image.fromarray to mitigate the deprecation of the | ||
mode paramter. See: | ||
https://pillow.readthedocs.io/en/stable/releasenotes/11.3.0.html#image-fromarray-mode-parameter | ||
""" | ||
if PILLOW_VERSION >= (11, 3): | ||
# We actually rely on the old behavior of Image.fromarray(): | ||
# | ||
# new behavior: PIL will infer the image mode from the data passed in. | ||
# That is, the type and shape determines the mode. | ||
# | ||
# old behiavor: The mode will change how PIL reads the image, | ||
# regardless of the data. That is, it will make the data | ||
# work with the mode. | ||
# | ||
# Our uses of Image.fromarray() are effectively a "turn into PIL image | ||
# AND convert the kind" operation. In particular, in | ||
# functional.to_pil_image() and transforms.ToPILImage. | ||
# | ||
# However, Image.frombuffer() still performs this conversion. The code | ||
# below is lifted from the new implementation of Image.fromarray(). We | ||
# omit the code that infers the mode, and use the code that figures out | ||
# from the data passed in (obj) what the correct parameters are to | ||
# Image.frombuffer(). | ||
# | ||
# Note that the alternate solution below does not work: | ||
# | ||
# img = Image.fromarray(obj) | ||
# img = img.convert(mode) | ||
Comment on lines
+224
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wellll that sucks. I was really hoping this would just be it. I personally don't even understand the claim that the mode can be inferred from the shape and dtype. Looking at the docs on mode https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes there seem to be plenty of opportunity for collisions:
I took a brief look at python-pillow/Pillow#9018, apparently Thank you for the fix @scotts ! Maybe we should add a comment to follow-up on this with a link to python-pillow/Pillow#9018 and python-pillow/Pillow#9063 |
||
# | ||
# The resulting image has very different actual pixel values than before. | ||
arr = obj.__array_interface__ | ||
shape = arr["shape"] | ||
ndim = len(shape) | ||
size = 1 if ndim == 1 else shape[1], shape[0] | ||
|
||
strides = arr.get("strides", None) | ||
contiguous_obj: Union[np.ndarray, bytes] = obj | ||
if strides is not None: | ||
# We require that the data is contiguous; if it is not, we need to | ||
# convert it into a contiguous format. | ||
if hasattr(obj, "tobytes"): | ||
contiguous_obj = obj.tobytes() | ||
elif hasattr(obj, "tostring"): | ||
contiguous_obj = obj.tostring() | ||
else: | ||
raise ValueError("Unable to convert obj into contiguous format") | ||
|
||
return Image.frombuffer(mode, size, contiguous_obj, "raw", mode, 0, 1) | ||
else: | ||
return Image.fromarray(obj, mode) | ||
|
||
|
||
@torch.no_grad() | ||
def save_image( | ||
tensor: Union[torch.Tensor, list[torch.Tensor]], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might want to remove the "todo" line.
PS. I'm just a bystander that got affect by these tests failures.