|
4 | 4 | import os |
5 | 5 | import re |
6 | 6 | import sys |
| 7 | +from contextlib import nullcontext |
7 | 8 | from pathlib import Path |
8 | 9 |
|
9 | 10 | import numpy as np |
|
13 | 14 | import torchvision.transforms.v2.functional as F |
14 | 15 | from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda |
15 | 16 | from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence |
| 17 | +from torchvision._internally_replaced_utils import IN_FBCODE |
16 | 18 | from torchvision.io.image import ( |
17 | 19 | _decode_avif, |
18 | 20 | _decode_heic, |
@@ -1044,5 +1046,45 @@ def test_decode_heic(decode_fun, scripted): |
1044 | 1046 | img += 123 # make sure image buffer wasn't freed by underlying decoding lib |
1045 | 1047 |
|
1046 | 1048 |
|
| 1049 | +@pytest.mark.parametrize("input_type", ("Path", "str", "tensor")) |
| 1050 | +@pytest.mark.parametrize("scripted", (False, True)) |
| 1051 | +def test_decode_image_path(input_type, scripted): |
| 1052 | + # Check that decode_image can support not just tensors as input |
| 1053 | + path = next(get_images(IMAGE_ROOT, ".jpg")) |
| 1054 | + if input_type == "Path": |
| 1055 | + input = Path(path) |
| 1056 | + elif input_type == "str": |
| 1057 | + input = path |
| 1058 | + elif input_type == "tensor": |
| 1059 | + input = read_file(path) |
| 1060 | + else: |
| 1061 | + raise ValueError("Oops") |
| 1062 | + |
| 1063 | + if scripted and input_type == "Path": |
| 1064 | + pytest.xfail(reason="Can't pass a Path when scripting") |
| 1065 | + |
| 1066 | + decode_fun = torch.jit.script(decode_image) if scripted else decode_image |
| 1067 | + decode_fun(input) |
| 1068 | + |
| 1069 | + |
| 1070 | +def test_mode_str(): |
| 1071 | + # Make sure decode_image supports string modes. We just test decode_image, |
| 1072 | + # not all of the decoding functions, but they should all support that too. |
| 1073 | + # Torchscript fails when passing strings, which is expected. |
| 1074 | + path = next(get_images(IMAGE_ROOT, ".png")) |
| 1075 | + assert decode_image(path, mode="RGB").shape[0] == 3 |
| 1076 | + assert decode_image(path, mode="rGb").shape[0] == 3 |
| 1077 | + assert decode_image(path, mode="GRAY").shape[0] == 1 |
| 1078 | + assert decode_image(path, mode="RGBA").shape[0] == 4 |
| 1079 | + |
| 1080 | + |
| 1081 | +def test_avif_heic_fbcode(): |
| 1082 | + cm = nullcontext() if IN_FBCODE else pytest.raises(ImportError, match="cannot import") |
| 1083 | + with cm: |
| 1084 | + from torchvision.io import decode_heic # noqa |
| 1085 | + with cm: |
| 1086 | + from torchvision.io import decode_avif # noqa |
| 1087 | + |
| 1088 | + |
1047 | 1089 | if __name__ == "__main__": |
1048 | 1090 | pytest.main([__file__]) |
0 commit comments