
Commit ee7e66b

NicolasHug and w-m authored and committed
[fbsync] Encode jpeg cuda sync (#8929)
Reviewed By: scotts

Differential Revision: D77997051

fbshipit-source-id: 3b0fc22020c8930219f866ccd64679c5095962e4

Co-authored-by: Wieland Morgenstern <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent f8f26dd commit ee7e66b
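
The commit moves the CUDA event so that it is recorded only after all encode work has been enqueued on the encoder's dedicated stream, and makes the caller's stream wait on that event; before this change the event was recorded before the loop, so waiting on it gave no guarantee that the encode kernels had finished. Below is a minimal Python sketch of that ordering rule using plain torch.cuda streams and events. It is illustrative only and is not the torchvision implementation (the actual change is the C++ diff further down); `encode_stream`, `caller_stream` and `done` are placeholder names.

```python
import torch

# Minimal sketch of the event ordering enforced by this commit (assumes a CUDA device;
# not torchvision code -- the real change is in encode_jpegs_cuda.cpp below).
assert torch.cuda.is_available()

caller_stream = torch.cuda.current_stream()  # analogue of CUDAJpegEncoder::current_stream
encode_stream = torch.cuda.Stream()          # analogue of the encoder's dedicated stream

done = torch.cuda.Event()
with torch.cuda.stream(encode_stream):
    x = torch.rand(3, 1024, 1024, device="cuda")
    y = (x * 255).to(torch.uint8)  # stand-in for the per-image encode loop

done.record(encode_stream)  # record *after* the work has been enqueued (the fix)
done.wait(caller_stream)    # future work on the caller's stream now orders after the encode

print(y.float().mean().item())  # safe to consume on the caller's stream, no host-side sync needed
```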

File tree

3 files changed: +53 −10

test/test_image.py
torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp
torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h


test/test_image.py

Lines changed: 36 additions & 0 deletions
@@ -623,6 +623,42 @@ def test_encode_jpeg_cuda(img_path, scripted, contiguous):
     assert abs_mean_diff < 3
 
 
+@needs_cuda
+def test_encode_jpeg_cuda_sync():
+    """
+    Non-regression test for https://github.com/pytorch/vision/issues/8587.
+    Attempts to reproduce an intermittent CUDA stream synchronization bug
+    by randomly creating images and round-tripping them via encode_jpeg
+    and decode_jpeg on the GPU. Fails if the mean difference in uint8 range
+    exceeds 5.
+    """
+    torch.manual_seed(42)
+
+    # manual testing shows this bug appearing often in iterations between 50 and 100
+    # as a synchronization bug, this can't be reliably reproduced
+    max_iterations = 100
+    threshold = 5.0  # in [0..255]
+
+    device = torch.device("cuda")
+
+    for iteration in range(max_iterations):
+        height, width = torch.randint(4000, 5000, size=(2,))
+
+        image = torch.linspace(0, 1, steps=height * width, device=device)
+        image = image.view(1, height, width).expand(3, -1, -1)
+
+        image = (image * 255).clamp(0, 255).to(torch.uint8)
+        jpeg_bytes = encode_jpeg(image, quality=100)
+
+        decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device)
+        mean_difference = (image.float() - decoded_image.float()).abs().mean().item()
+
+        assert mean_difference <= threshold, (
+            f"Encode/decode mismatch at iteration={iteration}, "
+            f"size={height}x{width}, mean diff={mean_difference:.2f}"
+        )
+
+
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("scripted", (True, False))
 @pytest.mark.parametrize("contiguous", (True, False))
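
For context on the thresholds: the pre-existing CUDA test above asserts a mean absolute difference below 3 for a quality-controlled round trip, and the new test fails at 5, so anything beyond that points to corrupted pixel data rather than ordinary JPEG loss. The snippet below is a small CPU-only illustration of the same metric (exact values depend on the image and the libjpeg build, but a quality=100 round trip of a smooth gradient stays well under these thresholds):

```python
import torch
from torchvision.io import decode_jpeg, encode_jpeg

# Mean absolute difference of a quality=100 JPEG round trip on a smooth gradient image,
# as a baseline for the thresholds used in the tests above.
image = torch.linspace(0, 1, steps=256 * 256).view(1, 256, 256).expand(3, -1, -1)
image = (image * 255).clamp(0, 255).to(torch.uint8).contiguous()

round_tripped = decode_jpeg(encode_jpeg(image, quality=100))
print((image.float() - round_tripped.float()).abs().mean().item())  # small, well under 5
```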

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp

Lines changed: 16 additions & 10 deletions
@@ -92,12 +92,12 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
 
   cudaJpegEncoder->set_quality(quality);
   std::vector<torch::Tensor> encoded_images;
-  at::cuda::CUDAEvent event;
-  event.record(cudaJpegEncoder->stream);
   for (const auto& image : contig_images) {
     auto encoded_image = cudaJpegEncoder->encode_jpeg(image);
     encoded_images.push_back(encoded_image);
   }
+  at::cuda::CUDAEvent event;
+  event.record(cudaJpegEncoder->stream);
 
   // We use a dedicated stream to do the encoding and even though the results
   // may be ready on that stream we cannot assume that they are also available
@@ -106,10 +106,7 @@
   // do not want to block the host at this particular point
   // (which is what cudaStreamSynchronize would do.) Events allow us to
   // synchronize the streams without blocking the host.
-  event.block(at::cuda::getCurrentCUDAStream(
-      cudaJpegEncoder->original_device.has_index()
-          ? cudaJpegEncoder->original_device.index()
-          : 0));
+  event.block(cudaJpegEncoder->current_stream);
   return encoded_images;
 }
 
@@ -119,7 +116,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
       stream{
           target_device.has_index()
               ? at::cuda::getStreamFromPool(false, target_device.index())
-              : at::cuda::getStreamFromPool(false)} {
+              : at::cuda::getStreamFromPool(false)},
+      current_stream{
+          original_device.has_index()
+              ? at::cuda::getCurrentCUDAStream(original_device.index())
+              : at::cuda::getCurrentCUDAStream()} {
   nvjpegStatus_t status;
   status = nvjpegCreateSimple(&nvjpeg_handle);
   TORCH_CHECK(
@@ -184,12 +185,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
 }
 
 torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
+  nvjpegStatus_t status;
+  cudaError_t cudaStatus;
+
+  // Ensure that the incoming src_image is safe to use
+  cudaStatus = cudaStreamSynchronize(current_stream);
+  TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);
+
   int channels = src_image.size(0);
   int height = src_image.size(1);
   int width = src_image.size(2);
 
-  nvjpegStatus_t status;
-  cudaError_t cudaStatus;
   status = nvjpegEncoderParamsSetSamplingFactors(
       nv_enc_params, NVJPEG_CSS_444, stream);
   TORCH_CHECK(
@@ -249,7 +255,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
       nv_enc_state,
       encoded_image.data_ptr<uint8_t>(),
       &length,
-      0);
+      stream);
   TORCH_CHECK(
       status == NVJPEG_STATUS_SUCCESS,
       "Failed to retrieve encoded image: ",

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ class CUDAJpegEncoder {
   const torch::Device original_device;
   const torch::Device target_device;
   const c10::cuda::CUDAStream stream;
+  const c10::cuda::CUDAStream current_stream;
 
  protected:
   nvjpegEncoderState_t nv_enc_state;
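
Taken together, the .h and .cpp changes cache the caller's current stream at construction and handle both directions of the cross-stream handoff: `encode_jpeg` now calls `cudaStreamSynchronize(current_stream)` before reading `src_image`, so the caller's stream has finished producing the input before the dedicated encode stream consumes it, and the event recorded after the encode loop makes the caller's stream wait for the results. The Python sketch below expresses the same two orderings stream-to-stream with plain torch.cuda (illustrative only, not the nvJPEG code above; `producer` and `encoder` are placeholder names); unlike the host-blocking `cudaStreamSynchronize` used for the input side in the diff, `wait_stream` orders the streams without involving the host.

```python
import torch

# Illustrative two-stream handoff (assumes a CUDA device; not the nvJPEG code above).
assert torch.cuda.is_available()

producer = torch.cuda.current_stream()  # stream on which the caller builds src_image
encoder = torch.cuda.Stream()           # analogue of CUDAJpegEncoder::stream

src_image = torch.rand(3, 4096, 4096, device="cuda")  # enqueued on `producer`

# Input side: do not read src_image until the producer stream has finished writing it.
encoder.wait_stream(producer)
with torch.cuda.stream(encoder):
    checksum = src_image.float().sum()  # stand-in for reading src_image during encoding

# Output side: the caller's stream must in turn wait for the encoder's work.
producer.wait_stream(encoder)
print(checksum.item())
```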
