@@ -38,9 +38,9 @@ CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue,
3838 cl::sycl::interop_handler &ih)
3939 : ih(ih),
4040 needToRecover_ (false ) {
41- placedContext_ = queue.get_context ();
41+ placedContext_ = new cl::sycl::context ( queue.get_context () );
4242 auto device = queue.get_device ();
43- auto desired = cl::sycl::get_native<cl::sycl::backend::cuda>(placedContext_);
43+ auto desired = cl::sycl::get_native<cl::sycl::backend::cuda>(* placedContext_);
4444 CUresult err;
4545 CUDA_ERROR_FUNC (cuCtxGetCurrent, err, &original_);
4646 if (original_ != desired) {
@@ -61,6 +61,7 @@ CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) {
6161 CUresult err;
6262 CUDA_ERROR_FUNC (cuCtxSetCurrent, err, original_);
6363 }
64+ delete placedContext_;
6465}
6566
6667void ContextCallback (void *userData) {
@@ -83,8 +84,8 @@ void ContextCallback(void *userData) {
8384}
8485
8586cublasHandle_t CublasScopedContextHandler::get_handle (const cl::sycl::queue &queue) {
86- auto piPlacedContext_ =
87- reinterpret_cast <pi_context>( cl::sycl::get_native<cl::sycl::backend::cuda>(placedContext_));
87+ auto piPlacedContext_ = reinterpret_cast <pi_context>(
88+ cl::sycl::get_native<cl::sycl::backend::cuda>(* placedContext_));
8889 CUstream streamId = get_stream (queue);
8990 cublasStatus_t err;
9091 auto it = handle_helper.cublas_handle_mapper_ .find (piPlacedContext_);
@@ -116,7 +117,7 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const cl::sycl::queue &que
116117 auto insert_iter = handle_helper.cublas_handle_mapper_ .insert (
117118 std::make_pair (piPlacedContext_, new std::atomic<cublasHandle_t>(handle)));
118119
119- sycl::detail::pi::contextSetExtendedDeleter (placedContext_, ContextCallback,
120+ sycl::detail::pi::contextSetExtendedDeleter (* placedContext_, ContextCallback,
120121 insert_iter.first ->second );
121122
122123 return handle;
0 commit comments