Skip to content

Commit 111aafb

Browse files
dominikkallusky, w-m, and NicolasHug
authored
Encode jpeg cuda sync (#8929)
Co-authored-by: Wieland Morgenstern <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 77e95fc commit 111aafb

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

test/test_image.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,42 @@ def test_encode_jpeg_cuda(img_path, scripted, contiguous):
623623
assert abs_mean_diff < 3
624624

625625

626+
@needs_cuda
def test_encode_jpeg_cuda_sync():
    """Non-regression test for https://github.com/pytorch/vision/issues/8587.

    Stress-tests GPU encode_jpeg/decode_jpeg round-trips on freshly created
    images to try to trigger an intermittent CUDA stream synchronization bug.
    Fails if the mean absolute difference (uint8 range) ever exceeds 5.
    """
    torch.manual_seed(42)

    device = torch.device("cuda")

    # Manual testing showed the bug appearing most often between iterations
    # 50 and 100; being a synchronization bug it cannot be reliably reproduced.
    n_iters = 100
    max_mean_diff = 5.0  # tolerance, expressed in the [0..255] uint8 range

    for iteration in range(n_iters):
        height, width = torch.randint(4000, 5000, size=(2,))

        # Build a 3-channel uint8 gradient image directly on the GPU.
        image = torch.linspace(0, 1, steps=height * width, device=device)
        image = image.view(1, height, width).expand(3, -1, -1)
        image = (image * 255).clamp(0, 255).to(torch.uint8)

        # Round-trip: encode on the GPU, decode the bytes back onto the GPU.
        jpeg_bytes = encode_jpeg(image, quality=100)
        decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device)

        mean_difference = (image.float() - decoded_image.float()).abs().mean().item()
        assert mean_difference <= max_mean_diff, (
            f"Encode/decode mismatch at iteration={iteration}, "
            f"size={height}x{width}, mean diff={mean_difference:.2f}"
        )
660+
661+
626662
@pytest.mark.parametrize("device", cpu_and_cuda())
627663
@pytest.mark.parametrize("scripted", (True, False))
628664
@pytest.mark.parametrize("contiguous", (True, False))

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
9494

9595
cudaJpegEncoder->set_quality(quality);
9696
std::vector<torch::Tensor> encoded_images;
97-
at::cuda::CUDAEvent event;
98-
event.record(cudaJpegEncoder->stream);
9997
for (const auto& image : contig_images) {
10098
auto encoded_image = cudaJpegEncoder->encode_jpeg(image);
10199
encoded_images.push_back(encoded_image);
102100
}
101+
at::cuda::CUDAEvent event;
102+
event.record(cudaJpegEncoder->stream);
103103

104104
// We use a dedicated stream to do the encoding and even though the results
105105
// may be ready on that stream we cannot assume that they are also available
@@ -108,10 +108,7 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
108108
// do not want to block the host at this particular point
109109
// (which is what cudaStreamSynchronize would do.) Events allow us to
110110
// synchronize the streams without blocking the host.
111-
event.block(at::cuda::getCurrentCUDAStream(
112-
cudaJpegEncoder->original_device.has_index()
113-
? cudaJpegEncoder->original_device.index()
114-
: 0));
111+
event.block(cudaJpegEncoder->current_stream);
115112
return encoded_images;
116113
}
117114

@@ -121,7 +118,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
121118
stream{
122119
target_device.has_index()
123120
? at::cuda::getStreamFromPool(false, target_device.index())
124-
: at::cuda::getStreamFromPool(false)} {
121+
: at::cuda::getStreamFromPool(false)},
122+
current_stream{
123+
original_device.has_index()
124+
? at::cuda::getCurrentCUDAStream(original_device.index())
125+
: at::cuda::getCurrentCUDAStream()} {
125126
nvjpegStatus_t status;
126127
status = nvjpegCreateSimple(&nvjpeg_handle);
127128
TORCH_CHECK(
@@ -186,12 +187,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
186187
}
187188

188189
torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
190+
nvjpegStatus_t status;
191+
cudaError_t cudaStatus;
192+
193+
// Ensure that the incoming src_image is safe to use
194+
cudaStatus = cudaStreamSynchronize(current_stream);
195+
TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);
196+
189197
int channels = src_image.size(0);
190198
int height = src_image.size(1);
191199
int width = src_image.size(2);
192200

193-
nvjpegStatus_t status;
194-
cudaError_t cudaStatus;
195201
status = nvjpegEncoderParamsSetSamplingFactors(
196202
nv_enc_params, NVJPEG_CSS_444, stream);
197203
TORCH_CHECK(
@@ -251,7 +257,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
251257
nv_enc_state,
252258
encoded_image.data_ptr<uint8_t>(),
253259
&length,
254-
0);
260+
stream);
255261
TORCH_CHECK(
256262
status == NVJPEG_STATUS_SUCCESS,
257263
"Failed to retrieve encoded image: ",

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class CUDAJpegEncoder {
2222
const torch::Device original_device;
2323
const torch::Device target_device;
2424
const c10::cuda::CUDAStream stream;
25+
const c10::cuda::CUDAStream current_stream;
2526

2627
protected:
2728
nvjpegEncoderState_t nv_enc_state;

0 commit comments

Comments
 (0)