Skip to content

Commit 01297e1

Browse files
committed
Comments etc.
1 parent 10ff9fe commit 01297e1

File tree

4 files changed

+67
-4
lines changed

4 files changed

+67
-4
lines changed

docs/source/io.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ encode/decode JPEGs on CUDA.
4141
decode_image
4242
decode_jpeg
4343
encode_png
44-
decode_gif
4544
decode_webp
45+
decode_avif
46+
decode_heic
47+
decode_gif
4648

4749
.. autosummary::
4850
:toctree: generated/

test/smoke_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def smoke_test_torchvision_read_decode() -> None:
3939
# support CUDA.
4040
# Strangely, on the CPU runners where this fails, the AVIF/HEIC
4141
# tests (run with pytest) are passing. This is likely related to a
42-
# libcxx symbol thing.
42+
# libcxx symbol thing, and the proper libstdc++.so gets loaded only
43+
# with pytest? Ugh.
4344
img_avif = decode_avif(read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif")))
4445
if img_avif.shape != (3, 100, 100):
4546
raise RuntimeError(f"Unexpected shape of img_avif: {img_avif.shape}")

torchvision/csrc/io/image/cpu/decode_image.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ torch::Tensor decode_image(
2222
"Expected a non empty 1-dimensional tensor");
2323

2424
auto err_msg =
25-
"Unsupported image file. Only jpeg, png and gif are currently supported.";
25+
"Unsupported image file. Only jpeg, png, webp and gif are currently supported. For avif and heic format, please rely on `decode_avif` and `decode_heic` directly.";
2626

2727
auto datap = data.data_ptr<uint8_t>();
2828

torchvision/io/image.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,14 @@ def decode_webp(
377377
return torch.ops.image.decode_webp(input, mode.value)
378378

379379

380+
# TODO_AVIF_HEIC: Better support for torchscript. Scripting decode_avif or
381+
# decode_heic currently fails, mainly because of the logic in
382+
# _load_extra_decoders_once() (using global variables, try/except statements,
383+
# etc.).
384+
# The ops (torch.ops.extra_decoders_ns.decode_*) are otherwise torchscript-able,
385+
# and users who need torchscript can always just wrap those.
386+
387+
380388
_EXTRA_DECODERS_ALREADY_LOADED = False
381389

382390

@@ -388,9 +396,23 @@ def _load_extra_decoders_once():
388396
try:
389397
import torchvision_extra_decoders
390398

399+
# torchvision-extra-decoders only supports linux for now. BUT, users on
400+
# e.g. MacOS can still install it: they will get the pure-python
401+
# 0.0.0.dev version:
402+
# https://pypi.org/project/torchvision-extra-decoders/0.0.0.dev0, which
403+
# is a dummy version that was created to reserve the namespace on PyPI.
404+
# We have to check that expose_extra_decoders() exists for those users,
405+
# so we can properly error on non-Linux archs.
391406
assert hasattr(torchvision_extra_decoders, "expose_extra_decoders")
392407
except (AssertionError, ImportError) as e:
393-
raise RuntimeError("You need to pip install torchvision-extra-decoders blah blah blah") from e
408+
raise RuntimeError(
409+
"In order to enable the AVIF and HEIC decoding capabilities of "
410+
"torchvision, you need to `pip install torchvision-extra-decoders`. "
411+
"Just install the package, you don't need to update your code. "
412+
"This is only supported on Linux, and this feature is still in BETA stage. "
413+
"Please let us know of any issue: https://github.com/pytorch/vision/issues/new/choose. "
414+
"Note that `torchvision-extra-decoders` is released under the LGPL license. "
415+
) from e
394416

395417
# This will expose torch.ops.extra_decoders_ns.decode_avif and torch.ops.extra_decoders_ns.decode_heic
396418
torchvision_extra_decoders.expose_extra_decoders()
@@ -399,13 +421,51 @@ def _load_extra_decoders_once():
399421

400422

401423
def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
424+
"""Decode an AVIF image into a 3 dimensional RGB[A] Tensor.
425+
426+
The values of the output tensor are in uint8 in [0, 255] for most images. If
427+
the image has a bit-depth of more than 8, then the output tensor is uint16
428+
in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
429+
calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
430+
``scale=True`` after this function to convert the decoded image into a uint8
431+
or float tensor.
432+
433+
Args:
434+
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
435+
the raw bytes of the AVIF image.
436+
mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
437+
Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
438+
for available modes.
439+
440+
Returns:
441+
Decoded image (Tensor[image_channels, image_height, image_width])
442+
"""
402443
_load_extra_decoders_once()
403444
if input.dtype != torch.uint8:
404445
raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
405446
return torch.ops.extra_decoders_ns.decode_avif(input, mode.value)
406447

407448

408449
def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
450+
"""Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
451+
452+
The values of the output tensor are in uint8 in [0, 255] for most images. If
453+
the image has a bit-depth of more than 8, then the output tensor is uint16
454+
in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
455+
calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
456+
``scale=True`` after this function to convert the decoded image into a uint8
457+
or float tensor.
458+
459+
Args:
460+
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
461+
the raw bytes of the HEIC image.
462+
mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
463+
Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
464+
for available modes.
465+
466+
Returns:
467+
Decoded image (Tensor[image_channels, image_height, image_width])
468+
"""
409469
_load_extra_decoders_once()
410470
if input.dtype != torch.uint8:
411471
raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")

0 commit comments

Comments
 (0)