@@ -108,10 +108,7 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
108108 // do not want to block the host at this particular point
109109 // (which is what cudaStreamSynchronize would do.) Events allow us to
110110 // synchronize the streams without blocking the host.
111- event.block (at::cuda::getCurrentCUDAStream (
112- cudaJpegEncoder->original_device .has_index ()
113- ? cudaJpegEncoder->original_device .index ()
114- : 0 ));
111+ event.block (cudaJpegEncoder->current_stream );
115112 return encoded_images;
116113}
117114
@@ -121,7 +118,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
121118 stream{
122119 target_device.has_index ()
123120 ? at::cuda::getStreamFromPool (false , target_device.index ())
124- : at::cuda::getStreamFromPool (false )} {
121+ : at::cuda::getStreamFromPool (false )},
122+ current_stream{
123+ original_device.has_index ()
124+ ? at::cuda::getCurrentCUDAStream (original_device.index ())
125+ : at::cuda::getCurrentCUDAStream ()} {
125126 nvjpegStatus_t status;
126127 status = nvjpegCreateSimple (&nvjpeg_handle);
127128 TORCH_CHECK (
@@ -186,12 +187,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
186187}
187188
188189torch::Tensor CUDAJpegEncoder::encode_jpeg (const torch::Tensor& src_image) {
190+ nvjpegStatus_t status;
191+ cudaError_t cudaStatus;
192+
193+ // Ensure that the incoming src_image is safe to use
194+ cudaStatus = cudaStreamSynchronize (current_stream);
195+ TORCH_CHECK (cudaStatus == cudaSuccess, " CUDA ERROR: " , cudaStatus);
196+
189197 int channels = src_image.size (0 );
190198 int height = src_image.size (1 );
191199 int width = src_image.size (2 );
192200
193- nvjpegStatus_t status;
194- cudaError_t cudaStatus;
195201 status = nvjpegEncoderParamsSetSamplingFactors (
196202 nv_enc_params, NVJPEG_CSS_444, stream);
197203 TORCH_CHECK (
0 commit comments