Skip to content

Commit 7df3ee3

Browse files
Fix CUDA stream synchronization bug in encode_jpegs_cuda
1 parent 813622d commit 7df3ee3

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

test/test_image.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ def test_encode_jpeg_cuda(img_path, scripted, contiguous):
622622
abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
623623
assert abs_mean_diff < 3
624624

625+
625626
@needs_cuda
626627
def test_encode_jpeg_cuda_sync():
627628
"""
@@ -636,14 +637,13 @@ def test_encode_jpeg_cuda_sync():
636637

637638
# manual testing shows this bug appearing often in iterations between 50 and 100
638639
# as a synchronization bug, this can't be reliably reproduced
639-
max_iterations = 200
640+
max_iterations = 100
640641
threshold = 5.0 # in [0..255]
641642

642643
device = torch.device("cuda")
643644

644645
for iteration in range(max_iterations):
645-
# Randomly pick a small square image size in [1..64]
646-
size = np.random.randint(1, 65)
646+
size = np.random.randint(4000, 5000)
647647
height, width = size, size
648648

649649
image = torch.linspace(0, 1, steps=height * width, device=device)

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,7 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
108108
// do not want to block the host at this particular point
109109
// (which is what cudaStreamSynchronize would do). Events allow us to
110110
// synchronize the streams without blocking the host.
111-
event.block(at::cuda::getCurrentCUDAStream(
112-
cudaJpegEncoder->original_device.has_index()
113-
? cudaJpegEncoder->original_device.index()
114-
: 0));
111+
event.block(cudaJpegEncoder->current_stream);
115112
return encoded_images;
116113
}
117114

@@ -121,7 +118,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
121118
stream{
122119
target_device.has_index()
123120
? at::cuda::getStreamFromPool(false, target_device.index())
124-
: at::cuda::getStreamFromPool(false)} {
121+
: at::cuda::getStreamFromPool(false)},
122+
current_stream{
123+
original_device.has_index()
124+
? at::cuda::getCurrentCUDAStream(original_device.index())
125+
: at::cuda::getCurrentCUDAStream()} {
125126
nvjpegStatus_t status;
126127
status = nvjpegCreateSimple(&nvjpeg_handle);
127128
TORCH_CHECK(
@@ -186,12 +187,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
186187
}
187188

188189
torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
190+
nvjpegStatus_t status;
191+
cudaError_t cudaStatus;
192+
193+
// Ensure that the incoming src_image is safe to use
194+
cudaStatus = cudaStreamSynchronize(current_stream);
195+
TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);
196+
189197
int channels = src_image.size(0);
190198
int height = src_image.size(1);
191199
int width = src_image.size(2);
192200

193-
nvjpegStatus_t status;
194-
cudaError_t cudaStatus;
195201
status = nvjpegEncoderParamsSetSamplingFactors(
196202
nv_enc_params, NVJPEG_CSS_444, stream);
197203
TORCH_CHECK(

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class CUDAJpegEncoder {
2222
const torch::Device original_device;
2323
const torch::Device target_device;
2424
const c10::cuda::CUDAStream stream;
25+
const c10::cuda::CUDAStream current_stream;
2526

2627
protected:
2728
nvjpegEncoderState_t nv_enc_state;

0 commit comments

Comments
 (0)