Skip to content

Commit e2d571c

Browse files
committed
Slightly simplify test
1 parent 7df3ee3 commit e2d571c

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

test/test_image.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -629,10 +629,8 @@ def test_encode_jpeg_cuda_sync():
629629
Attempts to reproduce an intermittent CUDA stream synchronization bug
630630
by randomly creating small images and round-tripping them via encode_jpeg
631631
and decode_jpeg on the GPU. Fails if the mean difference in u8 range exceeds 5.0.
632-
633632
https://github.com/pytorch/vision/issues/8587
634633
"""
635-
np.random.seed(42)
636634
torch.manual_seed(42)
637635

638636
# manual testing shows this bug appearing often in iterations between 50 and 100
@@ -643,23 +641,22 @@ def test_encode_jpeg_cuda_sync():
643641
device = torch.device("cuda")
644642

645643
for iteration in range(max_iterations):
646-
size = np.random.randint(4000, 5000)
647-
height, width = size, size
644+
height, width = torch.randint(4000, 5000, size=(2,))
648645

649646
image = torch.linspace(0, 1, steps=height * width, device=device)
650647
image = image.view(1, height, width).expand(3, -1, -1)
651648

652-
image_uint8 = (image * 255).clamp(0, 255).to(torch.uint8)
653-
jpeg_bytes = encode_jpeg(image_uint8, quality=100)
649+
image = (image * 255).clamp(0, 255).to(torch.uint8)
650+
jpeg_bytes = encode_jpeg(image, quality=100)
654651

655-
decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device).float() / 255.0
656-
mean_difference = (image - decoded_image).abs().mean().item() * 255
652+
decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device)
653+
mean_difference = (image.float() - decoded_image.float()).abs().mean().item()
654+
print(mean_difference)
657655

658-
if mean_difference > threshold:
659-
pytest.fail(
660-
f"Encode/decode mismatch at iteration={iteration}, "
661-
f"size={height}x{width}, mean diff={mean_difference:.2f}"
662-
)
656+
assert mean_difference <= threshold, (
657+
f"Encode/decode mismatch at iteration={iteration}, "
658+
f"size={height}x{width}, mean diff={mean_difference:.2f}"
659+
)
663660

664661

665662
@pytest.mark.parametrize("device", cpu_and_cuda())

0 commit comments

Comments
 (0)