[release/0.23] Cherry-pick PIL mode mitigation #9153

Merged
3 changes: 1 addition & 2 deletions setup.py
@@ -111,8 +111,7 @@ def get_dist(pkgname):
]

# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
-# TODO remove <11.3 bound and address corresponding deprecation warnings
-pillow_ver = " >= 5.3.0, !=8.3.*, <11.3"
+pillow_ver = " >= 5.3.0, !=8.3.*"
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
requirements.append(pillow_req + pillow_ver)

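As a sanity check on the relaxed constraint, a minimal sketch (using the third-party packaging library; not part of this diff) of what the new specifier admits:

from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=5.3.0,!=8.3.*")  # the new, uncapped constraint
print(spec.contains("8.3.2"))   # False -- 8.3.* still excluded (issue #4934)
print(spec.contains("11.3.0"))  # True  -- now allowed; handled by _Image_fromarray below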
6 changes: 3 additions & 3 deletions test/common_utils.py
@@ -12,16 +12,16 @@
from subprocess import CalledProcessError, check_output, STDOUT

import numpy as np
-import PIL.Image
+import PIL
import pytest
import torch
import torch.testing
-from PIL import Image

from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
+from torchvision.utils import _Image_fromarray


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -147,7 +147,7 @@ def _create_data(height=3, width=3, channels=3, device="cpu"):
if channels == 1:
mode = "L"
data = data[..., 0]
-pil_img = Image.fromarray(data, mode=mode)
+pil_img = _Image_fromarray(data, mode=mode)
return tensor, pil_img


21 changes: 11 additions & 10 deletions test/test_transforms.py
@@ -13,6 +13,7 @@
import torchvision.transforms.functional as F
from PIL import Image
from torch._utils_internal import get_file_path_2
+from torchvision.utils import _Image_fromarray

try:
import accimage
@@ -654,7 +655,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
img_F_mode = transforms.ToPILImage(mode="F")(img_data)
assert img_F_mode.mode == "F"
torch.testing.assert_close(
-np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
+np.array(_Image_fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
)

@pytest.mark.parametrize("with_mode", [False, True])
@@ -895,7 +896,7 @@ def test_adjust_brightness():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")

# test 0
y_pil = F.adjust_brightness(x_pil, 1)
@@ -921,7 +922,7 @@ def test_adjust_contrast():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")

# test 0
y_pil = F.adjust_contrast(x_pil, 1)
@@ -947,7 +948,7 @@ def test_adjust_hue():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")

with pytest.raises(ValueError):
F.adjust_hue(x_pil, -0.7)
@@ -1029,7 +1030,7 @@ def test_adjust_sharpness():
117,
]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")

# test 0
y_pil = F.adjust_sharpness(x_pil, 1)
@@ -1152,7 +1153,7 @@ def test_adjust_sharpness():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")
x_th = torch.tensor(x_np.transpose(2, 0, 1))
y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil).transpose(2, 0, 1)
@@ -1164,7 +1165,7 @@ def test_adjust_gamma():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")

# test 0
y_pil = F.adjust_gamma(x_pil, 1)
@@ -1190,7 +1191,7 @@ def test_adjusts_L_mode():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_rgb = Image.fromarray(x_np, mode="RGB")
+x_rgb = _Image_fromarray(x_np, mode="RGB")

x_l = x_rgb.convert("L")
assert F.adjust_brightness(x_l, 2).mode == "L"
@@ -1320,7 +1321,7 @@ def test_to_grayscale():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")
gray_np = np.array(x_pil_2)

@@ -1769,7 +1770,7 @@ def test_color_jitter():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-x_pil = Image.fromarray(x_np, mode="RGB")
+x_pil = _Image_fromarray(x_np, mode="RGB")
x_pil_2 = x_pil.convert("L")

for _ in range(10):
6 changes: 3 additions & 3 deletions torchvision/datasets/mnist.py
@@ -11,8 +11,8 @@

import numpy as np
import torch
-from PIL import Image

+from ..utils import _Image_fromarray
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset

@@ -140,7 +140,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:

# doing this so that it is consistent with all other datasets
# to return a PIL Image
-img = Image.fromarray(img.numpy(), mode="L")
+img = _Image_fromarray(img.numpy(), mode="L")

if self.transform is not None:
img = self.transform(img)
@@ -478,7 +478,7 @@ def download(self) -> None:
def __getitem__(self, index: int) -> tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
-img = Image.fromarray(img.numpy(), mode="L")
+img = _Image_fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.compat:
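For reference, the dataset contract these call sites preserve; a minimal usage sketch (assumes the MNIST files are already on disk or can be downloaded):

from torchvision.datasets import MNIST

ds = MNIST(root="./data", download=True)
img, target = ds[0]
print(img.mode, type(target))  # "L" <class 'int'> -- still a 1-channel PIL image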
4 changes: 2 additions & 2 deletions torchvision/datasets/semeion.py
@@ -3,8 +3,8 @@
from typing import Any, Callable, Optional, Union

import numpy as np
-from PIL import Image

+from ..utils import _Image_fromarray
from .utils import check_integrity, download_url
from .vision import VisionDataset

@@ -64,7 +64,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:

# doing this so that it is consistent with all other datasets
# to return a PIL Image
-img = Image.fromarray(img, mode="L")
+img = _Image_fromarray(img, mode="L")

if self.transform is not None:
img = self.transform(img)
4 changes: 2 additions & 2 deletions torchvision/datasets/usps.py
@@ -3,8 +3,8 @@
from typing import Any, Callable, Optional, Union

import numpy as np
-from PIL import Image

+from ..utils import _Image_fromarray
from .utils import download_url
from .vision import VisionDataset

@@ -82,7 +82,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:

# doing this so that it is consistent with all other datasets
# to return a PIL Image
-img = Image.fromarray(img, mode="L")
+img = _Image_fromarray(img, mode="L")

if self.transform is not None:
img = self.transform(img)
6 changes: 4 additions & 2 deletions torchvision/transforms/_functional_pil.py
@@ -6,6 +6,8 @@
import torch
from PIL import Image, ImageEnhance, ImageOps

+from ..utils import _Image_fromarray

try:
import accimage
except ImportError:
@@ -113,7 +115,7 @@ def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
# This will over/underflow, as desired
np_h += np.int32(hue_factor * 255).astype(np.uint8)

-h = Image.fromarray(np_h, "L")
+h = _Image_fromarray(np_h, "L")

img = Image.merge("HSV", (h, s, v)).convert(input_mode)
return img
@@ -342,7 +344,7 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
img = img.convert("L")
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
-img = Image.fromarray(np_img, "RGB")
+img = _Image_fromarray(np_img, "RGB")
else:
raise ValueError("num_output_channels should be either 1 or 3")

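The adjust_hue call site above sits right after a deliberate uint8 overflow; a small worked example of that wrap-around step (values are illustrative, not from the PR):

import numpy as np

np_h = np.array([[250, 10]], dtype=np.uint8)  # hue channel
np_h += np.int32(0.1 * 255).astype(np.uint8)  # adds 25; 250 + 25 wraps to 19
print(np_h)  # [[19 35]] -- hue wraps modulo 256, as the comment intends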
4 changes: 2 additions & 2 deletions torchvision/transforms/functional.py
@@ -16,7 +16,7 @@
except ImportError:
accimage = None

-from ..utils import _log_api_usage_once
+from ..utils import _Image_fromarray, _log_api_usage_once
from . import _functional_pil as F_pil, _functional_tensor as F_t


@@ -321,7 +321,7 @@ def to_pil_image(pic, mode=None):
if mode is None:
raise TypeError(f"Input type {npimg.dtype} is not supported")

-return Image.fromarray(npimg, mode=mode)
+return _Image_fromarray(npimg, mode=mode)


def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
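The mode argument matters most for float inputs, where to_pil_image must keep 32-bit float pixels rather than rescale to uint8; a minimal sketch mirroring the existing test_1_channel_float_tensor_to_pil_image test:

import torch
from torchvision.transforms.functional import to_pil_image

img_data = torch.rand(1, 4, 4)          # float32 in [0, 1]
img = to_pil_image(img_data, mode="F")  # routed through _Image_fromarray
print(img.mode)                         # "F"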
82 changes: 81 additions & 1 deletion torchvision/utils.py
@@ -8,10 +8,11 @@

import numpy as np
import torch
-from PIL import Image, ImageColor, ImageDraw, ImageFont
+from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont


__all__ = [
"_Image_fromarray",
"make_grid",
"save_image",
"draw_bounding_boxes",
@@ -174,6 +175,85 @@ def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_l
current_dash = not current_dash


def _Image_fromarray(
    obj: np.ndarray,
    mode: str,
) -> Image.Image:
    """
    A wrapper around PIL.Image.fromarray to mitigate the deprecation of the
    mode parameter. See:
    https://pillow.readthedocs.io/en/stable/releasenotes/11.3.0.html#image-fromarray-mode-parameter
    """

    # This may throw if the version string comes from a non-stable or
    # development install. We'll fall back to the old behavior in such cases.
    try:
        PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION_STRING.split("."))
    except Exception:
        PILLOW_VERSION = None

    if PILLOW_VERSION is not None and PILLOW_VERSION >= (11, 3):
        # The actual PR that implements the deprecation has more context for why
        # it was done, and also points out some problems:
        #
        # https://github.com/python-pillow/Pillow/pull/9018
        #
        # Our use case falls into those problems. We actually rely on the old
        # behavior of Image.fromarray():
        #
        #   new behavior: PIL will infer the image mode from the data passed
        #                 in. That is, the type and shape determine the mode.
        #
        #   old behavior: The mode will change how PIL reads the image,
        #                 regardless of the data. That is, it will make the
        #                 data work with the mode.
        #
        # Our uses of Image.fromarray() are effectively a "turn into PIL image
        # AND convert the kind" operation, in particular in
        # functional.to_pil_image() and transforms.ToPILImage.
        #
        # However, Image.frombuffer() still performs this conversion. The code
        # below is lifted from the new implementation of Image.fromarray(). We
        # omit the code that infers the mode, and use the code that figures out
        # from the data passed in (obj) what the correct parameters are to
        # Image.frombuffer().
        #
        # Note that the alternate solution below does not work:
        #
        #   img = Image.fromarray(obj)
        #   img = img.convert(mode)
        #
        # The resulting image has very different actual pixel values than before.
        #
        # TODO: Issue #9151. Pillow has an open PR to restore the functionality
        # we rely on:
        #
        # https://github.com/python-pillow/Pillow/pull/9063
        #
        # When that is part of a release, we can revisit this hack below.
        arr = obj.__array_interface__
        shape = arr["shape"]
        ndim = len(shape)
        # (width, height) tuple; the conditional expression binds tighter
        # than the comma.
        size = 1 if ndim == 1 else shape[1], shape[0]

        strides = arr.get("strides", None)
        contiguous_obj: Union[np.ndarray, bytes] = obj
        if strides is not None:
            # We require that the data is contiguous; if it is not, we need to
            # convert it into a contiguous format.
            if hasattr(obj, "tobytes"):
                contiguous_obj = obj.tobytes()
            elif hasattr(obj, "tostring"):
                contiguous_obj = obj.tostring()
            else:
                raise ValueError("Unable to convert obj into contiguous format")

        return Image.frombuffer(mode, size, contiguous_obj, "raw", mode, 0, 1)
    else:
        return Image.fromarray(obj, mode)


@torch.no_grad()
def save_image(
tensor: Union[torch.Tensor, list[torch.Tensor]],
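A minimal sketch of the wrapper's contract (not part of the PR): on any supported Pillow version, raw pixel values survive the round trip, which is exactly the property the fromarray mode deprecation threatened:

import numpy as np
from torchvision.utils import _Image_fromarray

arr = np.arange(16, dtype=np.uint8).reshape(4, 4)
img = _Image_fromarray(arr, mode="L")
assert img.mode == "L"
assert np.array_equal(np.array(img), arr)  # bytes reinterpreted, not converted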