Skip to content

Commit b6b863e

Browse files
committed
Fix, not sure why
1 parent cc76037 commit b6b863e

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchvision/io/image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,13 @@ def _load_extra_decoders_once():
398398

399399
def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
400400
_load_extra_decoders_once()
401+
if input.dtype != torch.uint8:
402+
raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
401403
return torch.ops.extra_decoders_ns.decode_avif(input, mode.value)
402404

403405

404406
def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
405407
_load_extra_decoders_once()
408+
if input.dtype != torch.uint8:
409+
raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}")
406410
return torch.ops.extra_decoders_ns.decode_heic(input, mode.value)

0 commit comments

Comments
 (0)