Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,61 @@ def test_decode_bad_huffman_images():
decode_jpeg(bad_huff)


def test_encode_jpeg_privateuseone_custom_backend():
privateuseone_name = torch._C._get_privateuse1_backend_name()
device = torch.device(privateuseone_name)

data = torch.randint(0, 256, size=(3, 4, 5), dtype=torch.uint8, device=device)

with pytest.raises(RuntimeError, match="encode_jpegs_privateuseone"):
encode_jpeg(data)

lib = torch.library.Library("image", "FRAGMENT")
called = {}

try:
lib.define("encode_jpegs_privateuseone(Tensor[] input, int quality=75) -> Tensor[]")

@torch.library.impl(lib, "encode_jpegs_privateuseone", "PrivateUse1")
def _encode_jpegs_privateuseone(input, quality=75):
called["value"] = True
return input
except RuntimeError:
pass
encoded = encode_jpeg(data)
assert called.get("value") is True
torch.testing.assert_close(encoded, data.cpu())


def test_decode_jpeg_privateuseone_custom_backend():
privateuseone_name = torch._C._get_privateuse1_backend_name()
device = torch.device(privateuseone_name)
data = torch.full((1, 2, 3), 233, dtype=torch.uint8)
# When the custom operator is not registered, an error should
# be reported and prompted to register decode_jpegs_privateuseone.
with pytest.raises(RuntimeError, match="decode_jpegs_privateuseone"):
decode_jpeg(data, device=device)

# Register a simple custom implementation to return the original data
called = {}
lib = torch.library.Library("image", "FRAGMENT")
try:
lib.define(
"decode_jpegs_privateuseone(Tensor[] input, int mode=0, bool apply_exif_orientation=False) -> Tensor[]"
)

@torch.library.impl(lib, "decode_jpegs_privateuseone", "CPU")
def _decode_jpegs_privateuseone(input, mode=0, apply_exif_orientation=False):
called["value"] = True
return input
except RuntimeError:
pass

output = decode_jpeg(data, device=device)
assert called.get("value") is True
torch.testing.assert_close(output, data)


@pytest.mark.parametrize(
"img_path",
[
Expand Down
44 changes: 42 additions & 2 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ def decode_jpeg(
raise ValueError("All elements of the input list must be tensors.")
if not all(t.device.type == "cpu" for t in input):
raise ValueError("Input list must contain tensors on CPU.")
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
if device.type == custom_privateuse1_name or device.type == "privateuseone":
# When the target device is privateuseone, switch to calling the custom decode_jpegs_privateuseone.
# This operator needs to be pre-registered by the user through torch.library.define/impl.
decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
if decoder is None:
raise RuntimeError(
"decode_jpeg tensors on PrivateUse1 device require registering "
"torch.ops.image.decode_jpegs_privateuseone."
)
return decoder(input, mode.value, apply_exif_orientation)
if device.type == "cuda":
return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
else:
Expand All @@ -218,6 +229,15 @@ def decode_jpeg(
else: # input is tensor
if input.device.type != "cpu":
raise ValueError("Input tensor must be a CPU tensor")
if device.type == custom_privateuse1_name or device.type == "privateuseone":
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
if decoder is None:
raise RuntimeError(
"decode_jpeg tensor on PrivateUse1 device require registering "
"torch.ops.image.decode_jpegs_privateuseone."
)
return decoder([input], mode.value, apply_exif_orientation)[0]
if device.type == "cuda":
return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
else:
Expand Down Expand Up @@ -246,16 +266,36 @@ def encode_jpeg(
_log_api_usage_once(encode_jpeg)
if quality < 1 or quality > 100:
raise ValueError("Image quality should be a positive number between 1 and 100")
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()

if isinstance(input, list):
if not input:
raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
if input[0].device.type == "cuda":
device_type = input[0].device.type
if device_type == custom_privateuse1_name or device_type == "privateuseone":
encoder = getattr(torch.ops.image, "encode_jpegs_privateuseone", None)
if encoder is None:
raise RuntimeError(
"encode_jpeg tensors on PrivateUse1 device require registering "
"torch.ops.image.encode_jpegs_privateuseone."
)
return encoder(input, quality)
if device_type == "cuda":
return torch.ops.image.encode_jpegs_cuda(input, quality)
else:
return [torch.ops.image.encode_jpeg(image, quality) for image in input]
else: # single input tensor
if input.device.type == "cuda":
device_type = input.device.type
if device_type == "cuda":
return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
elif device_type == custom_privateuse1_name or device_type == "privateuseone":
encoder = getattr(torch.ops.image, "encode_jpegs_privateuseone", None)
if encoder is None:
raise RuntimeError(
"encode_jpeg tensor on PrivateUse1 device require registering "
"torch.ops.image.encode_jpegs_privateuseone."
)
return encoder([input], quality)[0]
else:
return torch.ops.image.encode_jpeg(input, quality)

Expand Down