Skip to content

Commit 425a923

Browse files
committed
Add util for Image.fromarray to deal with mode deprecation
1 parent b079a96 commit 425a923

File tree

7 files changed

+38
-22
lines changed

7 files changed

+38
-22
lines changed

test/common_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@
1212
from subprocess import CalledProcessError, check_output, STDOUT
1313

1414
import numpy as np
15-
import PIL.Image
1615
import pytest
1716
import torch
1817
import torch.testing
19-
from PIL import Image
2018

2119
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2220
from torchvision import io, tv_tensors
2321
from torchvision.transforms._functional_tensor import _max_value as get_max_value
2422
from torchvision.transforms.v2.functional import to_image, to_pil_image
23+
from torchvision.utils import _Image_fromarray
2524

2625

2726
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -147,7 +146,7 @@ def _create_data(height=3, width=3, channels=3, device="cpu"):
147146
if channels == 1:
148147
mode = "L"
149148
data = data[..., 0]
150-
pil_img = Image.fromarray(data, mode=mode)
149+
pil_img = _Image_fromarray(data, mode=mode)
151150
return tensor, pil_img
152151

153152

test/test_transforms.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torchvision.transforms.functional as F
1414
from PIL import Image
1515
from torch._utils_internal import get_file_path_2
16+
from torchvision.utils import _Image_fromarray
1617

1718
try:
1819
import accimage
@@ -654,7 +655,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
654655
img_F_mode = transforms.ToPILImage(mode="F")(img_data)
655656
assert img_F_mode.mode == "F"
656657
torch.testing.assert_close(
657-
np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
658+
np.array(_Image_fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
658659
)
659660

660661
@pytest.mark.parametrize("with_mode", [False, True])
@@ -895,7 +896,7 @@ def test_adjust_brightness():
895896
x_shape = [2, 2, 3]
896897
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
897898
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
898-
x_pil = Image.fromarray(x_np, mode="RGB")
899+
x_pil = _Image_fromarray(x_np, mode="RGB")
899900

900901
# test 0
901902
y_pil = F.adjust_brightness(x_pil, 1)
@@ -921,7 +922,7 @@ def test_adjust_contrast():
921922
x_shape = [2, 2, 3]
922923
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
923924
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
924-
x_pil = Image.fromarray(x_np, mode="RGB")
925+
x_pil = _Image_fromarray(x_np, mode="RGB")
925926

926927
# test 0
927928
y_pil = F.adjust_contrast(x_pil, 1)
@@ -947,7 +948,7 @@ def test_adjust_hue():
947948
x_shape = [2, 2, 3]
948949
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
949950
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
950-
x_pil = Image.fromarray(x_np, mode="RGB")
951+
x_pil = _Image_fromarray(x_np, mode="RGB")
951952

952953
with pytest.raises(ValueError):
953954
F.adjust_hue(x_pil, -0.7)
@@ -1029,7 +1030,7 @@ def test_adjust_sharpness():
10291030
117,
10301031
]
10311032
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1032-
x_pil = Image.fromarray(x_np, mode="RGB")
1033+
x_pil = _Image_fromarray(x_np, mode="RGB")
10331034

10341035
# test 0
10351036
y_pil = F.adjust_sharpness(x_pil, 1)
@@ -1152,7 +1153,7 @@ def test_adjust_sharpness():
11521153
x_shape = [2, 2, 3]
11531154
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
11541155
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1155-
x_pil = Image.fromarray(x_np, mode="RGB")
1156+
x_pil = _Image_fromarray(x_np, mode="RGB")
11561157
x_th = torch.tensor(x_np.transpose(2, 0, 1))
11571158
y_pil = F.adjust_sharpness(x_pil, 2)
11581159
y_np = np.array(y_pil).transpose(2, 0, 1)
@@ -1164,7 +1165,7 @@ def test_adjust_gamma():
11641165
x_shape = [2, 2, 3]
11651166
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
11661167
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1167-
x_pil = Image.fromarray(x_np, mode="RGB")
1168+
x_pil = _Image_fromarray(x_np, mode="RGB")
11681169

11691170
# test 0
11701171
y_pil = F.adjust_gamma(x_pil, 1)
@@ -1190,7 +1191,7 @@ def test_adjusts_L_mode():
11901191
x_shape = [2, 2, 3]
11911192
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
11921193
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1193-
x_rgb = Image.fromarray(x_np, mode="RGB")
1194+
x_rgb = _Image_fromarray(x_np, mode="RGB")
11941195

11951196
x_l = x_rgb.convert("L")
11961197
assert F.adjust_brightness(x_l, 2).mode == "L"
@@ -1320,7 +1321,7 @@ def test_to_grayscale():
13201321
x_shape = [2, 2, 3]
13211322
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
13221323
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1323-
x_pil = Image.fromarray(x_np, mode="RGB")
1324+
x_pil = _Image_fromarray(x_np, mode="RGB")
13241325
x_pil_2 = x_pil.convert("L")
13251326
gray_np = np.array(x_pil_2)
13261327

@@ -1769,7 +1770,7 @@ def test_color_jitter():
17691770
x_shape = [2, 2, 3]
17701771
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
17711772
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1772-
x_pil = Image.fromarray(x_np, mode="RGB")
1773+
x_pil = _Image_fromarray(x_np, mode="RGB")
17731774
x_pil_2 = x_pil.convert("L")
17741775

17751776
for _ in range(10):

torchvision/datasets/mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
import numpy as np
1313
import torch
14-
from PIL import Image
1514

15+
from ..utils import _Image_fromarray
1616
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
1717
from .vision import VisionDataset
1818

@@ -140,7 +140,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:
140140

141141
# doing this so that it is consistent with all other datasets
142142
# to return a PIL Image
143-
img = Image.fromarray(img.numpy(), mode="L")
143+
img = _Image_fromarray(img.numpy(), mode="L")
144144

145145
if self.transform is not None:
146146
img = self.transform(img)
@@ -478,7 +478,7 @@ def download(self) -> None:
478478
def __getitem__(self, index: int) -> tuple[Any, Any]:
479479
# redefined to handle the compat flag
480480
img, target = self.data[index], self.targets[index]
481-
img = Image.fromarray(img.numpy(), mode="L")
481+
img = _Image_fromarray(img.numpy(), mode="L")
482482
if self.transform is not None:
483483
img = self.transform(img)
484484
if self.compat:

torchvision/datasets/semeion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, Callable, Optional, Union
44

55
import numpy as np
6-
from PIL import Image
76

7+
from ..utils import _Image_fromarray
88
from .utils import check_integrity, download_url
99
from .vision import VisionDataset
1010

@@ -64,7 +64,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:
6464

6565
# doing this so that it is consistent with all other datasets
6666
# to return a PIL Image
67-
img = Image.fromarray(img, mode="L")
67+
img = _Image_fromarray(img, mode="L")
6868

6969
if self.transform is not None:
7070
img = self.transform(img)

torchvision/datasets/usps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, Callable, Optional, Union
44

55
import numpy as np
6-
from PIL import Image
76

7+
from ..utils import _Image_fromarray
88
from .utils import download_url
99
from .vision import VisionDataset
1010

@@ -82,7 +82,7 @@ def __getitem__(self, index: int) -> tuple[Any, Any]:
8282

8383
# doing this so that it is consistent with all other datasets
8484
# to return a PIL Image
85-
img = Image.fromarray(img, mode="L")
85+
img = _Image_fromarray(img, mode="L")
8686

8787
if self.transform is not None:
8888
img = self.transform(img)

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
except ImportError:
1717
accimage = None
1818

19-
from ..utils import _log_api_usage_once
19+
from ..utils import _Image_fromarray, _log_api_usage_once
2020
from . import _functional_pil as F_pil, _functional_tensor as F_t
2121

2222

@@ -321,7 +321,7 @@ def to_pil_image(pic, mode=None):
321321
if mode is None:
322322
raise TypeError(f"Input type {npimg.dtype} is not supported")
323323

324-
return Image.fromarray(npimg, mode=mode)
324+
return _Image_fromarray(npimg, mode=mode)
325325

326326

327327
def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:

torchvision/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import torch
11+
from packaging import version
1112
from PIL import Image, ImageColor, ImageDraw, ImageFont
1213

1314

@@ -174,6 +175,21 @@ def dashed_line(self, xy, fill=None, width=0, joint=None, dash_length=5, space_l
174175
current_dash = not current_dash
175176

176177

178+
def _Image_fromarray(
179+
obj: Union[torch.Tensor, np.ndarray],
180+
mode: Optional[str],
181+
) -> Image:
182+
"""
183+
A wrapper around PIL.Image.fromarray to mitigate the deprecation of the
184+
mode paramter. See:
185+
https://pillow.readthedocs.io/en/stable/releasenotes/11.3.0.html#image-fromarray-mode-parameter
186+
"""
187+
if version.parse(Image.__version__) >= version.parse("11.3.0"):
188+
return Image.fromarray(obj)
189+
else:
190+
return Image.fromarray(obj, mode)
191+
192+
177193
@torch.no_grad()
178194
def save_image(
179195
tensor: Union[torch.Tensor, list[torch.Tensor]],

0 commit comments

Comments
 (0)