@@ -94,12 +94,12 @@ std::vector<torch::Tensor> encode_jpegs_cuda(

   cudaJpegEncoder->set_quality(quality);
   std::vector<torch::Tensor> encoded_images;
-  at::cuda::CUDAEvent event;
-  event.record(cudaJpegEncoder->stream);
   for (const auto& image : contig_images) {
     auto encoded_image = cudaJpegEncoder->encode_jpeg(image);
     encoded_images.push_back(encoded_image);
   }
+  at::cuda::CUDAEvent event;
+  event.record(cudaJpegEncoder->stream);

   // We use a dedicated stream to do the encoding and even though the results
   // may be ready on that stream we cannot assume that they are also available
@@ -108,10 +108,7 @@ std::vector<torch::Tensor> encode_jpegs_cuda(
   // do not want to block the host at this particular point
   // (which is what cudaStreamSynchronize would do.) Events allow us to
   // synchronize the streams without blocking the host.
-  event.block(at::cuda::getCurrentCUDAStream(
-      cudaJpegEncoder->original_device.has_index()
-          ? cudaJpegEncoder->original_device.index()
-          : 0));
+  event.block(cudaJpegEncoder->current_stream);
   return encoded_images;
 }

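The record/block pairing above is the standard CUDAEvent hand-off: record on the stream that produced the work, block on the stream that will consume it, and the wait happens on the device rather than on the host. Below is a minimal sketch of the same pattern in isolation, assuming a CUDA-enabled libtorch build and a CUDA input tensor; the function name scale_on_side_stream and the variable work_stream are illustrative and not part of this patch.

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>

torch::Tensor scale_on_side_stream(const torch::Tensor& input) {
  // The work that produced `input` was queued on the caller's current stream.
  auto current_stream = at::cuda::getCurrentCUDAStream();
  auto work_stream = at::cuda::getStreamFromPool(/*isHighPriority=*/false);

  torch::Tensor result;
  {
    // Device-side wait: work_stream will not run past this point until
    // everything already recorded on current_stream has finished.
    // The host is not blocked.
    at::cuda::CUDAEvent input_ready;
    input_ready.record(current_stream);
    input_ready.block(work_stream);

    // Launch the actual work on the side stream.
    c10::cuda::CUDAStreamGuard guard(work_stream);
    result = input * 2;
  }

  // Hand the result back: the caller's stream may not read `result` until
  // the side stream is done, again without a host synchronization.
  at::cuda::CUDAEvent result_ready;
  result_ready.record(work_stream);
  result_ready.block(current_stream);
  return result;
}

In the patch the same two steps are split up: the event is recorded on cudaJpegEncoder->stream only after all images have been queued in the loop, and it then blocks the caller's current_stream before the encoded tensors are returned.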
@@ -121,7 +118,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
       stream{
           target_device.has_index()
               ? at::cuda::getStreamFromPool(false, target_device.index())
-              : at::cuda::getStreamFromPool(false)} {
+              : at::cuda::getStreamFromPool(false)},
+      current_stream{
+          original_device.has_index()
+              ? at::cuda::getCurrentCUDAStream(original_device.index())
+              : at::cuda::getCurrentCUDAStream()} {
   nvjpegStatus_t status;
   status = nvjpegCreateSimple(&nvjpeg_handle);
   TORCH_CHECK(
@@ -186,12 +187,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
 }

 torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
+  nvjpegStatus_t status;
+  cudaError_t cudaStatus;
+
+  // Ensure that the incoming src_image is safe to use
+  cudaStatus = cudaStreamSynchronize(current_stream);
+  TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);
+
   int channels = src_image.size(0);
   int height = src_image.size(1);
   int width = src_image.size(2);

-  nvjpegStatus_t status;
-  cudaError_t cudaStatus;
   status = nvjpegEncoderParamsSetSamplingFactors(
       nv_enc_params, NVJPEG_CSS_444, stream);
   TORCH_CHECK(
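Unlike the event hand-off, the cudaStreamSynchronize added at the top of encode_jpeg is a host-side wait on the caller's stream, which guarantees src_image is fully materialized before any nvJPEG call reads it. A small sketch of that check in isolation; wait_for_producer is an illustrative name, not part of the patch.

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <torch/torch.h>

void wait_for_producer(const c10::cuda::CUDAStream& producer_stream) {
  // CUDAStream converts to the raw cudaStream_t, so the CUDA runtime API can
  // be called on it directly. This blocks the host until all work previously
  // queued on the producer's stream has completed.
  cudaError_t err = cudaStreamSynchronize(producer_stream);
  TORCH_CHECK(err == cudaSuccess, "CUDA ERROR: ", cudaGetErrorString(err));
}

The patch pays this host-side wait only to make the incoming src_image safe to read from the encoder's dedicated stream; the encoded results still travel back through the non-blocking event path shown earlier.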
@@ -251,7 +257,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
       nv_enc_state,
       encoded_image.data_ptr<uint8_t>(),
       &length,
-      0);
+      stream);
   TORCH_CHECK(
       status == NVJPEG_STATUS_SUCCESS,
       "Failed to retrieve encoded image: ",