Skip to content

Commit 824e8c8

Browse files
[release/0.23] Cherry-pick PIL mode mitigation (#9153)
Co-authored-by: Scott Schneider <[email protected]>
1 parent 9a8003e commit 824e8c8

File tree

9 files changed

+109
-27
lines changed

9 files changed

+109
-27
lines changed

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ def get_dist(pkgname):
111111
]
112112

113113
# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
114-
# TODO remove <11.3 bound and address corresponding deprecation warnings
115-
pillow_ver = " >= 5.3.0, !=8.3.*, <11.3"
114+
pillow_ver = " >= 5.3.0, !=8.3.*"
116115
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
117116
requirements.append(pillow_req + pillow_ver)
118117

test/common_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
from subprocess import CalledProcessError, check_output, STDOUT
1313

1414
import numpy as np
15-
import PIL.Image
15+
import PIL
1616
import pytest
1717
import torch
1818
import torch.testing
19-
from PIL import Image
2019

2120
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2221
from torchvision import io, tv_tensors
2322
from torchvision.transforms._functional_tensor import _max_value as get_max_value
2423
from torchvision.transforms.v2.functional import to_image, to_pil_image
24+
from torchvision.utils import _Image_fromarray
2525

2626

2727
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -147,7 +147,7 @@ def _create_data(height=3, width=3, channels=3, device="cpu"):
147147
if channels == 1:
148148
mode = "L"
149149
data = data[..., 0]
150-
pil_img = Image.fromarray(data, mode=mode)
150+
pil_img = _Image_fromarray(data, mode=mode)
151151
return tensor, pil_img
152152

153153

test/test_transforms.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torchvision.transforms.functional as F
1414
from PIL import Image
1515
from torch._utils_internal import get_file_path_2
16+
from torchvision.utils import _Image_fromarray
1617

1718
try:
1819
import accimage
@@ -654,7 +655,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
654655
img_F_mode = transforms.ToPILImage(mode="F")(img_data)
655656
assert img_F_mode.mode == "F"
656657
torch.testing.assert_close(
657-
np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
658+
np.array(_Image_fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
658659
)
659660

660661
@pytest.mark.parametrize("with_mode", [False, True])
@@ -895,7 +896,7 @@ def test_adjust_brightness():
895896
x_shape = [2, 2, 3]
896897
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
897898
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
898-
x_pil = Image.fromarray(x_np, mode="RGB")
899+
x_pil = _Image_fromarray(x_np, mode="RGB")
899900

900901
# test 0
901902
y_pil = F.adjust_brightness(x_pil, 1)
@@ -921,7 +922,7 @@ def test_adjust_contrast():
921922
x_shape = [2, 2, 3]
922923
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
923924
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
924-
x_pil = Image.fromarray(x_np, mode="RGB")
925+
x_pil = _Image_fromarray(x_np, mode="RGB")
925926

926927
# test 0
927928
y_pil = F.adjust_contrast(x_pil, 1)
@@ -947,7 +948,7 @@ def test_adjust_hue():
947948
x_shape = [2, 2, 3]
948949
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
949950
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
950-
x_pil = Image.fromarray(x_np, mode="RGB")
951+
x_pil = _Image_fromarray(x_np, mode="RGB")
951952

952953
with pytest.raises(ValueError):
953954
F.adjust_hue(x_pil, -0.7)
@@ -1029,7 +1030,7 @@ def test_adjust_sharpness():
10291030
117,
10301031
]
10311032
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1032-
x_pil = Image.fromarray(x_np, mode="RGB")
1033+
x_pil = _Image_fromarray(x_np, mode="RGB")
10331034

10341035
# test 0
10351036
y_pil = F.adjust_sharpness(x_pil, 1)
@@ -1152,7 +1153,7 @@ def test_adjust_sharpness():
11521153
x_shape = [2, 2, 3]
11531154
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
11541155
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1155-
x_pil = Image.fromarray(x_np, mode="RGB")
1156+
x_pil = _Image_fromarray(x_np, mode="RGB")
11561157
x_th = torch.tensor(x_np.transpose(2, 0, 1))
11571158
y_pil = F.adjust_sharpness(x_pil, 2)
11581159
y_np = np.array(y_pil).transpose(2, 0, 1)
@@ -1164,7 +1165,7 @@ def test_adjust_gamma():
11641165
x_shape = [2, 2, 3]
11651166
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
11661167
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1167-
x_pil = Image.fromarray(x_np, mode="RGB")
1168+
x_pil = _Image_fromarray(x_np, mode="RGB")
11681169

11691170
# test 0
11701171
y_pil = F.adjust_gamma(x_pil, 1)
@@ -1190,7 +1191,7 @@ def test_adjusts_L_mode():
11901191
x_shape = [2, 2, 3]
11911192
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
11921193
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1193-
x_rgb = Image.fromarray(x_np, mode="RGB")
1194+
x_rgb = _Image_fromarray(x_np, mode="RGB")
11941195

11951196
x_l = x_rgb.convert("L")
11961197
assert F.adjust_brightness(x_l, 2).mode == "L"
@@ -1320,7 +1321,7 @@ def test_to_grayscale():
13201321
x_shape = [2, 2, 3]
13211322
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
13221323
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1323-
x_pil = Image.fromarray(x_np, mode="RGB")
1324+
x_pil = _Image_fromarray(x_np, mode="RGB")
13241325
x_pil_2 = x_pil.convert("L")
13251326
gray_np = np.array(x_pil_2)
13261327

@@ -1769,7 +1770,7 @@ def test_color_jitter():
17691770
x_shape = [2, 2, 3]
17701771
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
17711772
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1772-
x_pil = Image.fromarray(x_np, mode="RGB")
1773+
x_pil = _Image_fromarray(x_np, mode="RGB")
17731774
x_pil_2 = x_pil.convert("L")
17741775

17751776
for _ in range(10):

torchvision/datasets/mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
import numpy as np
1313
import torch
14-
from PIL import Image
1514

15+
from ..utils import _Image_fromarray
1616
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
1717
from .vision import VisionDataset
1818

@@ -140,7 +140,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:
140140

141141
# doing this so that it is consistent with all other datasets
142142
# to return a PIL Image
143-
img = Image.fromarray(img.numpy(), mode="L")
143+
img = _Image_fromarray(img.numpy(), mode="L")
144144

145145
if self.transform is not None:
146146
img = self.transform(img)
@@ -478,7 +478,7 @@ def download(self) -> None:
478478
def __getitem__(self, index: int) -> tuple[Any, Any]:
479479
# redefined to handle the compat flag
480480
img, target = self.data[index], self.targets[index]
481-
img = Image.fromarray(img.numpy(), mode="L")
481+
img = _Image_fromarray(img.numpy(), mode="L")
482482
if self.transform is not None:
483483
img = self.transform(img)
484484
if self.compat:

torchvision/datasets/semeion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, Callable, Optional, Union
44

55
import numpy as np
6-
from PIL import Image
76

7+
from ..utils import _Image_fromarray
88
from .utils import check_integrity, download_url
99
from .vision import VisionDataset
1010

@@ -64,7 +64,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:
6464

6565
# doing this so that it is consistent with all other datasets
6666
# to return a PIL Image
67-
img = Image.fromarray(img, mode="L")
67+
img = _Image_fromarray(img, mode="L")
6868

6969
if self.transform is not None:
7070
img = self.transform(img)

torchvision/datasets/usps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, Callable, Optional, Union
44

55
import numpy as np
6-
from PIL import Image
76

7+
from ..utils import _Image_fromarray
88
from .utils import download_url
99
from .vision import VisionDataset
1010

@@ -82,7 +82,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:
8282

8383
# doing this so that it is consistent with all other datasets
8484
# to return a PIL Image
85-
img = Image.fromarray(img, mode="L")
85+
img = _Image_fromarray(img, mode="L")
8686

8787
if self.transform is not None:
8888
img = self.transform(img)

torchvision/transforms/_functional_pil.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torch
77
from PIL import Image, ImageEnhance, ImageOps
88

9+
from ..utils import _Image_fromarray
10+
911
try:
1012
import accimage
1113
except ImportError:
@@ -113,7 +115,7 @@ def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
113115
# This will over/underflow, as desired
114116
np_h += np.int32(hue_factor * 255).astype(np.uint8)
115117

116-
h = Image.fromarray(np_h, "L")
118+
h = _Image_fromarray(np_h, "L")
117119

118120
img = Image.merge("HSV", (h, s, v)).convert(input_mode)
119121
return img
@@ -342,7 +344,7 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
342344
img = img.convert("L")
343345
np_img = np.array(img, dtype=np.uint8)
344346
np_img = np.dstack([np_img, np_img, np_img])
345-
img = Image.fromarray(np_img, "RGB")
347+
img = _Image_fromarray(np_img, "RGB")
346348
else:
347349
raise ValueError("num_output_channels should be either 1 or 3")
348350

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
except ImportError:
1717
accimage = None
1818

19-
from ..utils import _log_api_usage_once
19+
from ..utils import _Image_fromarray, _log_api_usage_once
2020
from . import _functional_pil as F_pil, _functional_tensor as F_t
2121

2222

@@ -321,7 +321,7 @@ def to_pil_image(pic, mode=None):
321321
if mode is None:
322322
raise TypeError(f"Input type {npimg.dtype} is not supported")
323323

324-
return Image.fromarray(npimg, mode=mode)
324+
return _Image_fromarray(npimg, mode=mode)
325325

326326

327327
def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:

torchvision/utils.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
import numpy as np
1010
import torch
11-
from PIL import Image, ImageColor, ImageDraw, ImageFont
11+
from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont
1212

1313

1414
__all__ = [
15+
"_Image_fromarray",
1516
"make_grid",
1617
"save_image",
1718
"draw_bounding_boxes",
@@ -174,6 +175,85 @@ def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_l
174175
current_dash = not current_dash
175176

176177

178+
def _Image_fromarray(
179+
obj: np.ndarray,
180+
mode: str,
181+
) -> Image.Image:
182+
"""
183+
A wrapper around PIL.Image.fromarray to mitigate the deprecation of the
184+
mode paramter. See:
185+
https://pillow.readthedocs.io/en/stable/releasenotes/11.3.0.html#image-fromarray-mode-parameter
186+
"""
187+
188+
# This may throw if the version string is from an install that comes from a
189+
# non-stable or development version. We'll fall back to the old behavior in
190+
# such cases.
191+
try:
192+
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION_STRING.split("."))
193+
except Exception:
194+
PILLOW_VERSION = None
195+
196+
if PILLOW_VERSION is not None and PILLOW_VERSION >= (11, 3):
197+
# The actual PR that implements the deprecation has more context for why
198+
# it was done, and also points out some problems:
199+
#
200+
# https://github.com/python-pillow/Pillow/pull/9018
201+
#
202+
# Our use case falls into those problems. We actually rely on the old
203+
# behavior of Image.fromarray():
204+
#
205+
# new behavior: PIL will infer the image mode from the data passed
206+
# in. That is, the type and shape determines the mode.
207+
#
208+
# old behiavor: The mode will change how PIL reads the image,
209+
# regardless of the data. That is, it will make the
210+
# data work with the mode.
211+
#
212+
# Our uses of Image.fromarray() are effectively a "turn into PIL image
213+
# AND convert the kind" operation. In particular, in
214+
# functional.to_pil_image() and transforms.ToPILImage.
215+
#
216+
# However, Image.frombuffer() still performs this conversion. The code
217+
# below is lifted from the new implementation of Image.fromarray(). We
218+
# omit the code that infers the mode, and use the code that figures out
219+
# from the data passed in (obj) what the correct parameters are to
220+
# Image.frombuffer().
221+
#
222+
# Note that the alternate solution below does not work:
223+
#
224+
# img = Image.fromarray(obj)
225+
# img = img.convert(mode)
226+
#
227+
# The resulting image has very different actual pixel values than before.
228+
#
229+
# TODO: Issue #9151. Pillow has an open PR to restore the functionality
230+
# we rely on:
231+
#
232+
# https://github.com/python-pillow/Pillow/pull/9063
233+
#
234+
# When that is part of a release, we can revisit this hack below.
235+
arr = obj.__array_interface__
236+
shape = arr["shape"]
237+
ndim = len(shape)
238+
size = 1 if ndim == 1 else shape[1], shape[0]
239+
240+
strides = arr.get("strides", None)
241+
contiguous_obj: Union[np.ndarray, bytes] = obj
242+
if strides is not None:
243+
# We require that the data is contiguous; if it is not, we need to
244+
# convert it into a contiguous format.
245+
if hasattr(obj, "tobytes"):
246+
contiguous_obj = obj.tobytes()
247+
elif hasattr(obj, "tostring"):
248+
contiguous_obj = obj.tostring()
249+
else:
250+
raise ValueError("Unable to convert obj into contiguous format")
251+
252+
return Image.frombuffer(mode, size, contiguous_obj, "raw", mode, 0, 1)
253+
else:
254+
return Image.fromarray(obj, mode)
255+
256+
177257
@torch.no_grad()
178258
def save_image(
179259
tensor: Union[torch.Tensor, list[torch.Tensor]],

0 commit comments

Comments
 (0)