Skip to content

Commit 4cb5c17

Browse files
authored
[BLAS] Fix cublas perf (#169)
* [BLAS] Fix cublas perf * applying clang-format
1 parent f516ad1 commit 4cb5c17

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue,
3838
cl::sycl::interop_handler &ih)
3939
: ih(ih),
4040
needToRecover_(false) {
41-
placedContext_ = queue.get_context();
41+
placedContext_ = new cl::sycl::context(queue.get_context());
4242
auto device = queue.get_device();
43-
auto desired = cl::sycl::get_native<cl::sycl::backend::cuda>(placedContext_);
43+
auto desired = cl::sycl::get_native<cl::sycl::backend::cuda>(*placedContext_);
4444
CUresult err;
4545
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_);
4646
if (original_ != desired) {
@@ -61,6 +61,7 @@ CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) {
6161
CUresult err;
6262
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_);
6363
}
64+
delete placedContext_;
6465
}
6566

6667
void ContextCallback(void *userData) {
@@ -83,8 +84,8 @@ void ContextCallback(void *userData) {
8384
}
8485

8586
cublasHandle_t CublasScopedContextHandler::get_handle(const cl::sycl::queue &queue) {
86-
auto piPlacedContext_ =
87-
reinterpret_cast<pi_context>(cl::sycl::get_native<cl::sycl::backend::cuda>(placedContext_));
87+
auto piPlacedContext_ = reinterpret_cast<pi_context>(
88+
cl::sycl::get_native<cl::sycl::backend::cuda>(*placedContext_));
8889
CUstream streamId = get_stream(queue);
8990
cublasStatus_t err;
9091
auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_);
@@ -116,7 +117,7 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const cl::sycl::queue &que
116117
auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
117118
std::make_pair(piPlacedContext_, new std::atomic<cublasHandle_t>(handle)));
118119

119-
sycl::detail::pi::contextSetExtendedDeleter(placedContext_, ContextCallback,
120+
sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback,
120121
insert_iter.first->second);
121122

122123
return handle;

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ the handle must be destroyed when the context goes out of scope. This will bind
6262

6363
class CublasScopedContextHandler {
6464
CUcontext original_;
65-
cl::sycl::context placedContext_;
65+
cl::sycl::context *placedContext_;
6666
bool needToRecover_;
6767
cl::sycl::interop_handler &ih;
6868
static thread_local cublas_handle<pi_context> handle_helper;

0 commit comments

Comments
 (0)