diff --git a/test/test_image.py b/test/test_image.py
index b11dd67ca12..9bac78f2397 100644
--- a/test/test_image.py
+++ b/test/test_image.py
@@ -160,6 +160,64 @@ def test_decode_bad_huffman_images():
     decode_jpeg(bad_huff)
 
 
+def test_encode_jpeg_privateuseone_custom_backend():
+    """encode_jpeg errors until encode_jpegs_privateuseone is registered, then dispatches to it."""
+    privateuseone_name = torch._C._get_privateuse1_backend_name()
+    device = torch.device(privateuseone_name)
+
+    data = torch.randint(0, 256, size=(3, 4, 5), dtype=torch.uint8, device=device)
+
+    with pytest.raises(RuntimeError, match="encode_jpegs_privateuseone"):
+        encode_jpeg(data)
+
+    lib = torch.library.Library("image", "FRAGMENT")
+    called = {}
+
+    try:
+        lib.define("encode_jpegs_privateuseone(Tensor[] input, int quality=75) -> Tensor[]")
+
+        @torch.library.impl(lib, "encode_jpegs_privateuseone", "PrivateUse1")
+        def _encode_jpegs_privateuseone(input, quality=75):
+            called["value"] = True
+            return input
+    except RuntimeError:
+        pass
+    encoded = encode_jpeg(data)
+    assert called.get("value") is True
+    torch.testing.assert_close(encoded.cpu(), data.cpu())  # assert_close also checks devices, so compare on CPU
+
+
+def test_decode_jpeg_privateuseone_custom_backend():
+    """decode_jpeg errors until decode_jpegs_privateuseone is registered, then dispatches to it."""
+    privateuseone_name = torch._C._get_privateuse1_backend_name()
+    device = torch.device(privateuseone_name)
+    data = torch.full((1, 2, 3), 233, dtype=torch.uint8)
+    # When the custom operator is not registered, an error should
+    # be reported and prompted to register decode_jpegs_privateuseone.
+    with pytest.raises(RuntimeError, match="decode_jpegs_privateuseone"):
+        decode_jpeg(data, device=device)
+
+    # Register a simple custom implementation to return the original data
+    called = {}
+    lib = torch.library.Library("image", "FRAGMENT")
+    try:
+        lib.define(
+            "decode_jpegs_privateuseone(Tensor[] input, int mode=0, bool apply_exif_orientation=False) -> Tensor[]"
+        )
+
+        # decode_jpeg is handed CPU tensors (it rejects anything else), hence the "CPU" registration key.
+        @torch.library.impl(lib, "decode_jpegs_privateuseone", "CPU")
+        def _decode_jpegs_privateuseone(input, mode=0, apply_exif_orientation=False):
+            called["value"] = True
+            return input
+    except RuntimeError:
+        pass
+
+    output = decode_jpeg(data, device=device)
+    assert called.get("value") is True
+    torch.testing.assert_close(output, data)
+
+
 @pytest.mark.parametrize(
     "img_path",
     [
diff --git a/torchvision/io/image.py b/torchvision/io/image.py
index c88e58ca4ca..85e4c4973f1 100644
--- a/torchvision/io/image.py
+++ b/torchvision/io/image.py
@@ -210,6 +210,17 @@ def decode_jpeg(
             raise ValueError("All elements of the input list must be tensors.")
         if not all(t.device.type == "cpu" for t in input):
             raise ValueError("Input list must contain tensors on CPU.")
+        custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
+        if device.type == custom_privateuse1_name or device.type == "privateuseone":
+            # When the target device is privateuseone, switch to calling the custom decode_jpegs_privateuseone.
+            # This operator needs to be pre-registered by the user through torch.library.define/impl.
+            decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
+            if decoder is None:
+                raise RuntimeError(
+                    "decode_jpeg tensors on PrivateUse1 device require registering "
+                    "torch.ops.image.decode_jpegs_privateuseone."
+                )
+            return decoder(input, mode.value, apply_exif_orientation)
         if device.type == "cuda":
             return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
         else:
@@ -218,6 +229,17 @@ def decode_jpeg(
     else:  # input is tensor
         if input.device.type != "cpu":
             raise ValueError("Input tensor must be a CPU tensor")
+        # Look the backend name up before the comparison: the list branch is the only other
+        # place that binds it, and it does not run when a single tensor is passed.
+        custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
+        if device.type == custom_privateuse1_name or device.type == "privateuseone":
+            decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
+            if decoder is None:
+                raise RuntimeError(
+                    "decode_jpeg tensor on PrivateUse1 device require registering "
+                    "torch.ops.image.decode_jpegs_privateuseone."
+                )
+            return decoder([input], mode.value, apply_exif_orientation)[0]
         if device.type == "cuda":
             return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
         else:
@@ -246,15 +268,36 @@ def encode_jpeg(
         _log_api_usage_once(encode_jpeg)
     if quality < 1 or quality > 100:
         raise ValueError("Image quality should be a positive number between 1 and 100")
+    # Backend name currently bound to PrivateUse1; the default "privateuseone" spelling is matched as well.
+    custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
+
     if isinstance(input, list):
         if not input:
             raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
-        if input[0].device.type == "cuda":
+        device_type = input[0].device.type
+        if device_type == custom_privateuse1_name or device_type == "privateuseone":
+            encoder = getattr(torch.ops.image, "encode_jpegs_privateuseone", None)
+            if encoder is None:
+                raise RuntimeError(
+                    "encode_jpeg tensors on PrivateUse1 device require registering "
+                    "torch.ops.image.encode_jpegs_privateuseone."
+                )
+            return encoder(input, quality)
+        if device_type == "cuda":
             return torch.ops.image.encode_jpegs_cuda(input, quality)
         else:
             return [torch.ops.image.encode_jpeg(image, quality) for image in input]
     else:  # single input tensor
-        if input.device.type == "cuda":
+        device_type = input.device.type
+        if device_type == "cuda":
             return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
+        elif device_type == custom_privateuse1_name or device_type == "privateuseone":
+            encoder = getattr(torch.ops.image, "encode_jpegs_privateuseone", None)
+            if encoder is None:
+                raise RuntimeError(
+                    "encode_jpeg tensor on PrivateUse1 device require registering "
+                    "torch.ops.image.encode_jpegs_privateuseone."
+                )
+            return encoder([input], quality)[0]
         else:
             return torch.ops.image.encode_jpeg(input, quality)