@@ -92,12 +92,12 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
9292
9393 cudaJpegEncoder->set_quality (quality);
9494 std::vector<torch::Tensor> encoded_images;
95- at::cuda::CUDAEvent event;
96- event.record (cudaJpegEncoder->stream );
9795 for (const auto & image : contig_images) {
9896 auto encoded_image = cudaJpegEncoder->encode_jpeg (image);
9997 encoded_images.push_back (encoded_image);
10098 }
99+ at::cuda::CUDAEvent event;
100+ event.record (cudaJpegEncoder->stream );
101101
102102 // We use a dedicated stream to do the encoding and even though the results
103103 // may be ready on that stream we cannot assume that they are also available
@@ -106,10 +106,7 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
106106 // do not want to block the host at this particular point
107107 // (which is what cudaStreamSynchronize would do.) Events allow us to
108108 // synchronize the streams without blocking the host.
109- event.block (at::cuda::getCurrentCUDAStream (
110- cudaJpegEncoder->original_device .has_index ()
111- ? cudaJpegEncoder->original_device .index ()
112- : 0 ));
109+ event.block (cudaJpegEncoder->current_stream );
113110 return encoded_images;
114111}
115112
@@ -119,7 +116,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
119116 stream{
120117 target_device.has_index ()
121118 ? at::cuda::getStreamFromPool (false , target_device.index ())
122- : at::cuda::getStreamFromPool (false )} {
119+ : at::cuda::getStreamFromPool (false )},
120+ current_stream{
121+ original_device.has_index ()
122+ ? at::cuda::getCurrentCUDAStream (original_device.index ())
123+ : at::cuda::getCurrentCUDAStream ()} {
123124 nvjpegStatus_t status;
124125 status = nvjpegCreateSimple (&nvjpeg_handle);
125126 TORCH_CHECK (
@@ -184,12 +185,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
184185}
185186
186187torch::Tensor CUDAJpegEncoder::encode_jpeg (const torch::Tensor& src_image) {
188+ nvjpegStatus_t status;
189+ cudaError_t cudaStatus;
190+
191+ // Block the host until all work queued on current_stream has finished, so that the incoming src_image is safe to use
192+ cudaStatus = cudaStreamSynchronize (current_stream);
193+ TORCH_CHECK (cudaStatus == cudaSuccess, " CUDA ERROR: " , cudaStatus);
194+
187195 int channels = src_image.size (0 );
188196 int height = src_image.size (1 );
189197 int width = src_image.size (2 );
190198
191- nvjpegStatus_t status;
192- cudaError_t cudaStatus;
193199 status = nvjpegEncoderParamsSetSamplingFactors (
194200 nv_enc_params, NVJPEG_CSS_444, stream);
195201 TORCH_CHECK (
@@ -249,7 +255,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
249255 nv_enc_state,
250256 encoded_image.data_ptr <uint8_t >(),
251257 &length,
252- 0 );
258+ stream );
253259 TORCH_CHECK (
254260 status == NVJPEG_STATUS_SUCCESS,
255261 " Failed to retrieve encoded image: " ,
0 commit comments