Skip to content

Commit f2c684d

Browse files
committed
encode_jpeg cuda sync bug as test
1 parent 867521e commit f2c684d

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

test/test_image.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,45 @@ def test_encode_jpeg_cuda(img_path, scripted, contiguous):
622622
abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
623623
assert abs_mean_diff < 3
624624

625+
@needs_cuda
626+
def test_encode_jpeg_cuda_sync():
627+
"""
628+
Attempts to reproduce an intermittent CUDA stream synchronization bug
629+
by randomly creating small images and round-tripping them via encode_jpeg
630+
and decode_jpeg on the GPU. Fails if the mean difference in u8 range exceeds 5.0.
631+
632+
https://github.com/pytorch/vision/issues/8587
633+
"""
634+
np.random.seed(42)
635+
torch.manual_seed(42)
636+
637+
# manual testing shows this bug appearing often in iterations between 50 and 100
638+
# as a synchronization bug, this can't be reliably reproduced
639+
max_iterations = 200
640+
threshold = 5.0 # in [0..255]
641+
642+
device = torch.device("cuda")
643+
644+
for iteration in range(max_iterations):
645+
# Randomly pick a small square image size in [1..64]
646+
size = np.random.randint(1, 65)
647+
height, width = size, size
648+
649+
image = torch.linspace(0, 1, steps=height * width, device=device)
650+
image = image.view(1, height, width).expand(3, -1, -1)
651+
652+
image_uint8 = (image * 255).clamp(0, 255).to(torch.uint8)
653+
jpeg_bytes = encode_jpeg(image_uint8, quality=100)
654+
655+
decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device).float() / 255.0
656+
mean_difference = (image - decoded_image).abs().mean().item() * 255
657+
658+
if mean_difference > threshold:
659+
pytest.fail(
660+
f"Encode/decode mismatch at iteration={iteration}, "
661+
f"size={height}x{width}, mean diff={mean_difference:.2f}"
662+
)
663+
625664

626665
@pytest.mark.parametrize("device", cpu_and_cuda())
627666
@pytest.mark.parametrize("scripted", (True, False))

0 commit comments

Comments
 (0)