@@ -377,6 +377,14 @@ def decode_webp(
377377 return torch .ops .image .decode_webp (input , mode .value )
378378
379379
380+ # TODO_AVIF_HEIC: Better support for torchscript. Scripting decode_avif or
381+ # decode_heic currently fails, mainly because of the logic in
382+ # _load_extra_decoders_once() (using global variables, try/except statements,
383+ # etc.).
384+ # The ops (torch.ops.extra_decoders_ns.decode_*) are otherwise torchscript-able,
385+ # and users who need torchscript can always just wrap those.
386+
387+
# Module-level flag, initially False. Presumably flipped by
# _load_extra_decoders_once() after torchvision_extra_decoders has exposed its
# ops, so that the loading work happens at most once per process -- TODO confirm
# against that function's body.
380388_EXTRA_DECODERS_ALREADY_LOADED = False
381389
382390
@@ -388,9 +396,23 @@ def _load_extra_decoders_once():
388396 try :
389397 import torchvision_extra_decoders
390398
399+ # torchvision-extra-decoders only supports linux for now. BUT, users on
400+ # e.g. MacOS can still install it: they will get the pure-python
401+ # 0.0.0.dev version:
402+ # https://pypi.org/project/torchvision-extra-decoders/0.0.0.dev0, which
403+ # is a dummy version that was created to reserve the namespace on PyPI.
404+ # We have to check that expose_extra_decoders() exists for those users,
405+ # so we can properly error on non-Linux archs.
391406 assert hasattr (torchvision_extra_decoders , "expose_extra_decoders" )
392407 except (AssertionError , ImportError ) as e :
393- raise RuntimeError ("You need to pip install torchvision-extra-decoders blah blah blah" ) from e
408+ raise RuntimeError (
409+ "In order to enable the AVIF and HEIC decoding capabilities of "
410+ "torchvision, you need to `pip install torchvision-extra-decoders`. "
411+ "Just install the package, you don't need to update your code. "
412+ "This is only supported on Linux, and this feature is still in BETA stage. "
413+ "Please let us know of any issue: https://github.com/pytorch/vision/issues/new/choose. "
414+ "Note that `torchvision-extra-decoders` is released under the LGPL license. "
415+ ) from e
394416
395417 # This will expose torch.ops.extra_decoders_ns.decode_avif and torch.ops.extra_decoders_ns.decode_heic
396418 torchvision_extra_decoders .expose_extra_decoders ()
@@ -399,13 +421,51 @@ def _load_extra_decoders_once():
399421
400422
def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """Decode an AVIF image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the AVIF image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes. Strings are matched case-insensitively against
            the member names of :class:`~torchvision.io.ImageReadMode`.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])

    Raises:
        RuntimeError: If ``input`` is not a uint8 tensor, or if the
            ``torchvision-extra-decoders`` package cannot be loaded.
        KeyError: If ``mode`` is a string that does not name a member of
            :class:`~torchvision.io.ImageReadMode`.
    """
    # Registers torch.ops.extra_decoders_ns.decode_avif; raises RuntimeError
    # with installation instructions when the extra package is unavailable.
    _load_extra_decoders_once()
    if input.dtype != torch.uint8:
        raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
    # The documented contract allows ``mode`` to be a plain string such as
    # "RGB". Map it onto the matching enum member so that ``.value`` below
    # does not fail with AttributeError on a str.
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]
    return torch.ops.extra_decoders_ns.decode_avif(input, mode.value)
406447
407448
408449def decode_heic (input : torch .Tensor , mode : ImageReadMode = ImageReadMode .UNCHANGED ) -> torch .Tensor :
450+ """Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
451+
452+ The values of the output tensor are in uint8 in [0, 255] for most images. If
453+ the image has a bit-depth of more than 8, then the output tensor is uint16
454+ in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
455+ calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
456+ ``scale=True`` after this function to convert the decoded image into a uint8
457+ or float tensor.
458+
459+ Args:
460+ input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
461+ the raw bytes of the HEIC image.
462+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
463+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
464+ for available modes.
465+
466+ Returns:
467+ Decoded image (Tensor[image_channels, image_height, image_width])
468+ """
409469 _load_extra_decoders_once ()
410470 if input .dtype != torch .uint8 :
411471 raise RuntimeError (f"Input tensor must have uint8 data type, got { input .dtype } " )
0 commit comments