
Commit 3f2c2df

pytorchbot committed: 2024-08-28 nightly release (a59c939)
1 parent 81f341d · commit 3f2c2df

File tree: 9 files changed, +209 −30 lines

test/test_image.py

Lines changed: 86 additions & 1 deletion
@@ -875,7 +875,7 @@ def test_decode_gif_webp_errors(decode_fun):
     if decode_fun is decode_gif:
         expected_match = re.escape("DGifOpenFileName() failed - 103")
     elif decode_fun is decode_webp:
-        expected_match = "WebPDecodeRGB failed."
+        expected_match = "WebPGetFeatures failed."
     with pytest.raises(RuntimeError, match=expected_match):
         decode_fun(encoded_data)

@@ -891,6 +891,31 @@ def test_decode_webp(decode_fun, scripted):
     assert img[None].is_contiguous(memory_format=torch.channels_last)


+# This test is skipped because it requires webp images that we're not including
+# within the repo. The test images were downloaded from the different pages of
+# https://developers.google.com/speed/webp/gallery
+# Note that converting an RGBA image to RGB leads to bad results because the
+# transparent pixels aren't necessarily set to "black" or "white", they can be
+# random stuff. This is consistent with PIL results.
+@pytest.mark.skip(reason="Need to download test images first")
+@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
+@pytest.mark.parametrize("scripted", (False, True))
+@pytest.mark.parametrize(
+    "mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None))
+)
+@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
+def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
+    encoded_bytes = read_file(filename)
+    if scripted:
+        decode_fun = torch.jit.script(decode_fun)
+    img = decode_fun(encoded_bytes, mode=mode)
+    assert img[None].is_contiguous(memory_format=torch.channels_last)
+
+    pil_img = Image.open(filename).convert(pil_mode)
+    from_pil = F.pil_to_tensor(pil_img)
+    assert_equal(img, from_pil)
+
+
 @pytest.mark.xfail(reason="AVIF support not enabled yet.")
 @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
 @pytest.mark.parametrize("scripted", (False, True))

@@ -903,5 +928,65 @@ def test_decode_avif(decode_fun, scripted):
     assert img[None].is_contiguous(memory_format=torch.channels_last)


+@pytest.mark.xfail(reason="AVIF support not enabled yet.")
+# Note: decode_image fails because some of these files have a (valid) signature
+# we don't recognize. We should probably use libmagic....
+# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
+@pytest.mark.parametrize("decode_fun", (_decode_avif,))
+@pytest.mark.parametrize("scripted", (False, True))
+@pytest.mark.parametrize(
+    "mode, pil_mode",
+    (
+        (ImageReadMode.RGB, "RGB"),
+        (ImageReadMode.RGB_ALPHA, "RGBA"),
+        (ImageReadMode.UNCHANGED, None),
+    ),
+)
+@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"))
+def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
+    if "reversed_dimg_order" in str(filename):
+        # Pillow properly decodes this one, but we don't (order of parts of the
+        # image is wrong). This is due to a bug that was recently fixed in
+        # libavif. Hopefully this test will end up passing soon with a new
+        # libavif version https://github.com/AOMediaCodec/libavif/issues/2311
+        pytest.xfail()
+    import pillow_avif  # noqa
+
+    encoded_bytes = read_file(filename)
+    if scripted:
+        decode_fun = torch.jit.script(decode_fun)
+    try:
+        img = decode_fun(encoded_bytes, mode=mode)
+    except RuntimeError as e:
+        if any(
+            s in str(e)
+            for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image")
+        ):
+            pytest.skip(reason="Expected failure, that's OK")
+        else:
+            raise e
+    assert img[None].is_contiguous(memory_format=torch.channels_last)
+    if mode == ImageReadMode.RGB:
+        assert img.shape[0] == 3
+    if mode == ImageReadMode.RGB_ALPHA:
+        assert img.shape[0] == 4
+    if img.dtype == torch.uint16:
+        img = F.to_dtype(img, dtype=torch.uint8, scale=True)
+
+    from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
+    if False:
+        from torchvision.utils import make_grid
+
+        g = make_grid([img, from_pil])
+        F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))
+    if mode != ImageReadMode.RGB:
+        # We don't compare against PIL for RGB because results look pretty
+        # different on RGBA images (other images are fine). The result on
+        # torchvision basically just plainly ignores the alpha channel, resulting
+        # in transparent pixels looking dark. PIL seems to be using a sort of
+        # k-nn thing, looking at the output. Take a look at the resulting images.
+        torch.testing.assert_close(img, from_pil, rtol=0, atol=3)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

torchvision/csrc/io/image/cpu/decode_avif.cpp

Lines changed: 33 additions & 10 deletions
@@ -8,7 +8,9 @@ namespace vision {
 namespace image {

 #if !AVIF_FOUND
-torch::Tensor decode_avif(const torch::Tensor& data) {
+torch::Tensor decode_avif(
+    const torch::Tensor& encoded_data,
+    ImageReadMode mode) {
   TORCH_CHECK(
       false, "decode_avif: torchvision not compiled with libavif support");
 }

@@ -23,7 +25,9 @@ struct UniquePtrDeleter {
 };
 using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>;

-torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
+torch::Tensor decode_avif(
+    const torch::Tensor& encoded_data,
+    ImageReadMode mode) {
   // This is based on
   // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c
   // Refer there for more detail about what each function does, and which

@@ -58,24 +62,43 @@ torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
       avifResultToString(result));
   TORCH_CHECK(
       decoder->imageCount == 1, "Avif file contains more than one image");
-  TORCH_CHECK(
-      decoder->image->depth <= 8,
-      "avif images with bitdepth > 8 are not supported");

   result = avifDecoderNextImage(decoder.get());
   TORCH_CHECK(
       result == AVIF_RESULT_OK,
       "avifDecoderNextImage failed:",
       avifResultToString(result));

-  auto out = torch::empty(
-      {decoder->image->height, decoder->image->width, 3}, torch::kUInt8);
-
   avifRGBImage rgb;
   memset(&rgb, 0, sizeof(rgb));
   avifRGBImageSetDefaults(&rgb, decoder->image);
-  rgb.format = AVIF_RGB_FORMAT_RGB;
-  rgb.pixels = out.data_ptr<uint8_t>();
+
+  // images encoded as 10 or 12 bits will be decoded as uint16. The rest are
+  // decoded as uint8.
+  auto use_uint8 = (decoder->image->depth <= 8);
+  rgb.depth = use_uint8 ? 8 : 16;
+
+  if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
+      mode != IMAGE_READ_MODE_RGB_ALPHA) {
+    // Other modes aren't supported, but we don't error or even warn because we
+    // have generic entry points like decode_image which may support all modes,
+    // it just depends on the underlying decoder.
+    mode = IMAGE_READ_MODE_UNCHANGED;
+  }
+
+  // If return_rgb is false it means we return rgba - nothing else.
+  auto return_rgb =
+      (mode == IMAGE_READ_MODE_RGB ||
+       (mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent));
+
+  auto num_channels = return_rgb ? 3 : 4;
+  rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA;
+  rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE;
+
+  auto out = torch::empty(
+      {rgb.height, rgb.width, num_channels},
+      use_uint8 ? torch::kUInt8 : at::kUInt16);
+  rgb.pixels = (uint8_t*)out.data_ptr();
   rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb);

   result = avifImageYUVToRGB(decoder->image, &rgb);
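
The hunk above replaces the hard error on high bit-depth AVIF files with a uint16 decode path. A minimal sketch of how that output might be consumed from Python, assuming a libavif-enabled build (the tests still mark AVIF as xfail) and a hypothetical 10-bit sample path:

import torch
import torchvision.transforms.v2.functional as F
from torchvision.io.image import ImageReadMode, _decode_avif, read_file

# "hdr_10bit.avif" is a hypothetical sample; any 10- or 12-bit AVIF would do.
data = read_file("hdr_10bit.avif")
img = _decode_avif(data, mode=ImageReadMode.RGB)
print(img.dtype)  # torch.uint16 for >8-bit sources, torch.uint8 otherwise

# uint16 support in PyTorch is limited, so rescale down to uint8,
# as the new _decode_avif docstring recommends.
img_u8 = F.to_dtype(img, dtype=torch.uint8, scale=True)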
torchvision/csrc/io/image/cpu/decode_avif.h

Lines changed: 4 additions & 1 deletion

@@ -1,11 +1,14 @@
 #pragma once

 #include <torch/types.h>
+#include "../image_read_mode.h"

 namespace vision {
 namespace image {

-C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data);
+C10_EXPORT torch::Tensor decode_avif(
+    const torch::Tensor& encoded_data,
+    ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

 } // namespace image
 } // namespace vision

torchvision/csrc/io/image/cpu/decode_image.cpp

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ torch::Tensor decode_image(
       0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif"
   TORCH_CHECK(data.numel() >= 12, err_msg);
   if ((memcmp(avif_signature, datap + 4, 8) == 0)) {
-    return decode_avif(data);
+    return decode_avif(data, mode);
   }

   const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"

@@ -67,7 +67,7 @@ torch::Tensor decode_image(
   TORCH_CHECK(data.numel() >= 15, err_msg);
   if ((memcmp(webp_signature_begin, datap, 4) == 0) &&
       (memcmp(webp_signature_end, datap + 8, 7) == 0)) {
-    return decode_webp(data);
+    return decode_webp(data, mode);
   }

   TORCH_CHECK(false, err_msg);
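
With the two call sites updated, the generic decode_image entry point forwards the requested mode to whichever decoder the file signature selects. A small sketch from the Python side, assuming the wrappers from this commit and a hypothetical file path:

from torchvision.io.image import ImageReadMode, decode_image, read_file

# "photo.webp" is a hypothetical path; the same call would cover an .avif file
# once AVIF support is enabled, since dispatch is based on the file signature.
data = read_file("photo.webp")
img = decode_image(data, mode=ImageReadMode.RGB)
assert img.shape[0] == 3  # mode is honored by the underlying WebP decoder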

torchvision/csrc/io/image/cpu/decode_webp.cpp

Lines changed: 41 additions & 7 deletions
@@ -8,13 +8,17 @@ namespace vision {
 namespace image {

 #if !WEBP_FOUND
-torch::Tensor decode_webp(const torch::Tensor& data) {
+torch::Tensor decode_webp(
+    const torch::Tensor& encoded_data,
+    ImageReadMode mode) {
   TORCH_CHECK(
       false, "decode_webp: torchvision not compiled with libwebp support");
 }
 #else

-torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
+torch::Tensor decode_webp(
+    const torch::Tensor& encoded_data,
+    ImageReadMode mode) {
   TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
   TORCH_CHECK(
       encoded_data.dtype() == torch::kU8,

@@ -26,13 +30,43 @@ torch::Tensor decode_webp(const torch::Tensor& encoded_data) {
       encoded_data.dim(),
       " dims.");

+  auto encoded_data_p = encoded_data.data_ptr<uint8_t>();
+  auto encoded_data_size = encoded_data.numel();
+
+  WebPBitstreamFeatures features;
+  auto res = WebPGetFeatures(encoded_data_p, encoded_data_size, &features);
+  TORCH_CHECK(
+      res == VP8_STATUS_OK, "WebPGetFeatures failed with error code ", res);
+  TORCH_CHECK(
+      !features.has_animation, "Animated webp files are not supported.");
+
+  if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
+      mode != IMAGE_READ_MODE_RGB_ALPHA) {
+    // Other modes aren't supported, but we don't error or even warn because we
+    // have generic entry points like decode_image which may support all modes,
+    // it just depends on the underlying decoder.
+    mode = IMAGE_READ_MODE_UNCHANGED;
+  }
+
+  // If return_rgb is false it means we return rgba - nothing else.
+  auto return_rgb =
+      (mode == IMAGE_READ_MODE_RGB ||
+       (mode == IMAGE_READ_MODE_UNCHANGED && !features.has_alpha));
+
+  auto decoding_func = return_rgb ? WebPDecodeRGB : WebPDecodeRGBA;
+  auto num_channels = return_rgb ? 3 : 4;
+
   int width = 0;
   int height = 0;
-  auto decoded_data = WebPDecodeRGB(
-      encoded_data.data_ptr<uint8_t>(), encoded_data.numel(), &width, &height);
-  TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed.");
-  auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8);
-  return out.permute({2, 0, 1}); // return CHW, channels-last
+
+  auto decoded_data =
+      decoding_func(encoded_data_p, encoded_data_size, &width, &height);
+  TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");
+
+  auto out = torch::from_blob(
+      decoded_data, {height, width, num_channels}, torch::kUInt8);
+
+  return out.permute({2, 0, 1});
 }
 #endif // WEBP_FOUND
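
A hedged summary of the mode handling implemented above, as seen from Python (assumptions: a libwebp-enabled build; the file path is hypothetical):

from torchvision.io.image import ImageReadMode, decode_webp, read_file

data = read_file("logo_with_alpha.webp")  # hypothetical file with an alpha channel

rgb = decode_webp(data, mode=ImageReadMode.RGB)         # always 3 channels
rgba = decode_webp(data, mode=ImageReadMode.RGB_ALPHA)  # always 4 channels
unchanged = decode_webp(data)                           # 4 here, 3 if the file had no alpha
gray = decode_webp(data, mode=ImageReadMode.GRAY)       # unsupported mode, silently treated as UNCHANGED

print(rgb.shape[0], rgba.shape[0], unchanged.shape[0], gray.shape[0])  # 3 4 4 4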

torchvision/csrc/io/image/cpu/decode_webp.h

Lines changed: 4 additions & 1 deletion

@@ -1,11 +1,14 @@
 #pragma once

 #include <torch/types.h>
+#include "../image_read_mode.h"

 namespace vision {
 namespace image {

-C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data);
+C10_EXPORT torch::Tensor decode_webp(
+    const torch::Tensor& encoded_data,
+    ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

 } // namespace image
 } // namespace vision

torchvision/csrc/io/image/image.cpp

Lines changed: 4 additions & 2 deletions
@@ -21,8 +21,10 @@ static auto registry =
         .op("image::encode_png", &encode_png)
         .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
             &decode_jpeg)
-        .op("image::decode_webp", &decode_webp)
-        .op("image::decode_avif", &decode_avif)
+        .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
+            &decode_webp)
+        .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor",
+            &decode_avif)
         .op("image::encode_jpeg", &encode_jpeg)
         .op("image::read_file", &read_file)
         .op("image::write_file", &write_file)

torchvision/io/image.py

Lines changed: 34 additions & 5 deletions
@@ -28,6 +28,11 @@ class ImageReadMode(Enum):
     ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
     ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
     RGB with transparency.
+
+    .. note::
+
+        Some decoders won't support all possible values, e.g. a decoder may only
+        support "RGB" and "RGBA" mode.
     """

     UNCHANGED = 0

@@ -365,28 +370,52 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor:

 def decode_webp(
     input: torch.Tensor,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
 ) -> torch.Tensor:
     """
-    Decode a WEBP image into a 3 dimensional RGB Tensor.
+    Decode a WEBP image into a 3 dimensional RGB[A] Tensor.

-    The values of the output tensor are uint8 between 0 and 255. If the input
-    image is RGBA, the transparency is ignored.
+    The values of the output tensor are uint8 between 0 and 255.

     Args:
         input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
             the raw bytes of the WEBP image.
+        mode (ImageReadMode): The read mode used for optionally
+            converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
+            Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.

     Returns:
         Decoded image (Tensor[image_channels, image_height, image_width])
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(decode_webp)
-    return torch.ops.image.decode_webp(input)
+    return torch.ops.image.decode_webp(input, mode.value)


 def _decode_avif(
     input: torch.Tensor,
+    mode: ImageReadMode = ImageReadMode.UNCHANGED,
 ) -> torch.Tensor:
+    """
+    Decode an AVIF image into a 3 dimensional RGB[A] Tensor.
+
+    The values of the output tensor are in uint8 in [0, 255] for most images. If
+    the image has a bit-depth of more than 8, then the output tensor is uint16
+    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
+    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
+    ``scale=True`` after this function to convert the decoded image into a uint8
+    or float tensor.
+
+    Args:
+        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
+            the raw bytes of the AVIF image.
+        mode (ImageReadMode): The read mode used for optionally
+            converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
+            Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
+
+    Returns:
+        Decoded image (Tensor[image_channels, image_height, image_width])
+    """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(decode_webp)
-    return torch.ops.image.decode_avif(input)
+    return torch.ops.image.decode_avif(input, mode.value)
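
A short usage sketch of the updated public API, assuming a libwebp-enabled build and a hypothetical file path; as the new tests exercise, the mode argument also works when the function is scripted:

import torch
from torchvision.io.image import ImageReadMode, decode_webp, read_file

data = read_file("sample.webp")  # hypothetical path
rgb = decode_webp(data, mode=ImageReadMode.RGB)  # 3 x H x W, uint8

scripted = torch.jit.script(decode_webp)
rgba = scripted(data, mode=ImageReadMode.RGB_ALPHA)  # 4 x H x W, uint8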

torchvision/transforms/v2/functional/_color.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int


 def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
-    """See :class:`~torchvision.transforms.v2.GrayscaleToRgb` for details."""
+    """See :class:`~torchvision.transforms.v2.RGB` for details."""
     if torch.jit.is_scripting():
         return grayscale_to_rgb_image(inpt)
