Skip to content

Commit 08a8df3

Browse files
committed
Allow decode_image to support paths
1 parent c36025a commit 08a8df3

File tree

3 files changed

+38
-31
lines changed

3 files changed

+38
-31
lines changed

docs/source/io.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
1919
:toctree: generated/
2020
:template: function.rst
2121

22-
read_image
2322
decode_image
2423
encode_jpeg
2524
decode_jpeg
@@ -38,6 +37,13 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
3837

3938
ImageReadMode
4039

40+
Obsolete decoding function:
41+
42+
.. autosummary::
43+
:toctree: generated/
44+
:template: class.rst
45+
46+
read_image
4147

4248

4349
Video

test/test_image.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,5 +1044,26 @@ def test_decode_heic(decode_fun, scripted):
10441044
img += 123 # make sure image buffer wasn't freed by underlying decoding lib
10451045

10461046

1047+
@pytest.mark.parametrize("input_type", ("Path", "str", "tensor"))
1048+
@pytest.mark.parametrize("scripted", (False, True))
1049+
def test_decode_image_path(input_type, scripted):
1050+
# Check that decode_image can support not just tensors as input
1051+
path = next(get_images(IMAGE_ROOT, ".jpg"))
1052+
if input_type == "Path":
1053+
input = Path(path)
1054+
elif input_type == "str":
1055+
input = path
1056+
elif input_type == "tensor":
1057+
input = read_file(path)
1058+
else:
1059+
raise ValueError("Oops")
1060+
1061+
if scripted and input_type == "Path":
1062+
pytest.xfail(reason="Can't pass a Path when scripting")
1063+
1064+
decode_fun = torch.jit.script(decode_image) if scripted else decode_image
1065+
decode_fun(input)
1066+
1067+
10471068
if __name__ == "__main__":
10481069
pytest.main([__file__])

torchvision/io/image.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,13 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
277277

278278

279279
def decode_image(
280-
input: torch.Tensor,
280+
input: Union[torch.Tensor, str],
281281
mode: ImageReadMode = ImageReadMode.UNCHANGED,
282282
apply_exif_orientation: bool = False,
283283
) -> torch.Tensor:
284-
"""
285-
Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the
286-
appropriate operation to decode the image into a Tensor.
284+
"""Decode an image into a tensor.
285+
286+
Currently supported image formats are jpeg, png, gif and webp.
287287
288288
The values of the output tensor are in uint8 in [0, 255] for most cases.
289289
@@ -295,8 +295,9 @@ def decode_image(
295295
tensor.
296296
297297
Args:
298-
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
299-
image.
298+
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
299+
tensor is passed, it must be one dimensional uint8 tensor containing
300+
the raw bytes of the image. Otherwise, this must be a path to the image file.
300301
mode (ImageReadMode): the read mode used for optionally converting the image.
301302
Default: ``ImageReadMode.UNCHANGED``.
302303
See ``ImageReadMode`` class for more information on various
@@ -309,6 +310,8 @@ def decode_image(
309310
"""
310311
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
311312
_log_api_usage_once(decode_image)
313+
if not isinstance(input, torch.Tensor):
314+
input = read_file(str(input))
312315
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
313316
return output
314317

@@ -318,30 +321,7 @@ def read_image(
318321
mode: ImageReadMode = ImageReadMode.UNCHANGED,
319322
apply_exif_orientation: bool = False,
320323
) -> torch.Tensor:
321-
"""
322-
Reads a JPEG, PNG, WEBP, or GIF image into a Tensor.
323-
324-
The values of the output tensor are in uint8 in [0, 255] for most cases.
325-
326-
If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
327-
(supported from torchvision ``0.21``. Since uint16 support is limited in
328-
pytorch, we recommend calling
329-
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
330-
after this function to convert the decoded image into a uint8 or float
331-
tensor.
332-
333-
Args:
334-
path (str or ``pathlib.Path``): path of the image.
335-
mode (ImageReadMode): the read mode used for optionally converting the image.
336-
Default: ``ImageReadMode.UNCHANGED``.
337-
See ``ImageReadMode`` class for more information on various
338-
available modes. Only applies to JPEG and PNG images.
339-
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
340-
Only applies to JPEG and PNG images. Default: False.
341-
342-
Returns:
343-
output (Tensor[image_channels, image_height, image_width])
344-
"""
324+
"""[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
345325
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
346326
_log_api_usage_once(read_image)
347327
data = read_file(path)

0 commit comments

Comments
 (0)