@@ -32,36 +32,80 @@ namespace cublas {
3232 */
3333thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
3434
35- CublasScopedContextHandler::CublasScopedContextHandler (sycl::interop_handle& ih) : ih(ih) {}
35+ CublasScopedContextHandler::CublasScopedContextHandler (sycl::interop_handle& ih) : ih(ih) {
36+ // Initialize streamID member to a CUstream associated with the queue `ih`
37+ // has been submitted to.
38+ streamId = ih.get_native_queue <sycl::backend::ext_oneapi_cuda>();
3639
37- cublasHandle_t CublasScopedContextHandler::get_handle () {
40+ // Initialize the ` cublasHandle_t` member `nativeHandle`
3841 CUdevice device = ih.get_native_device <sycl::backend::ext_oneapi_cuda>();
39- CUstream streamId = get_stream ();
40- cublasStatus_t err;
41-
4242 auto it = handle_helper.cublas_handle_mapper_ .find (device);
4343 if (it != handle_helper.cublas_handle_mapper_ .end ()) {
44- cublasHandle_t nativeHandle = it->second ;
44+ // Use existing handle if one already exists for the device, but update
45+ // the native stream.
46+ nativeHandle = it->second ;
4547 cudaStream_t currentStreamId;
48+ cublasStatus_t err;
4649 CUBLAS_ERROR_FUNC (cublasGetStream, err, nativeHandle, ¤tStreamId);
4750 if (currentStreamId != streamId) {
4851 CUBLAS_ERROR_FUNC (cublasSetStream, err, nativeHandle, streamId);
4952 }
50- return nativeHandle;
5153 }
52-
53- cublasHandle_t nativeHandle;
54- CUBLAS_ERROR_FUNC (cublasCreate, err, &nativeHandle);
55- CUBLAS_ERROR_FUNC (cublasSetStream, err, nativeHandle, streamId);
56-
57- auto insert_iter =
54+ else {
55+ // Create a new handle if one doesn't already exist for the device
56+ cublasStatus_t err;
57+ CUBLAS_ERROR_FUNC (cublasCreate, err, &nativeHandle);
58+ CUBLAS_ERROR_FUNC (cublasSetStream, err, nativeHandle, streamId);
5859 handle_helper.cublas_handle_mapper_ .insert (std::make_pair (device, nativeHandle));
60+ }
61+ }
5962
60- return nativeHandle;
63+ void CublasScopedContextHandler::begin_recording_if_graph () {
64+ // interop_handle graph methods only available from extension version 2
65+ #if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
66+ if (!ih.ext_codeplay_has_graph ()) {
67+ return ;
68+ }
69+
70+ CUresult err;
71+ #if CUDA_VERSION >= 12030
72+ // After CUDA 12.3 we can use cuStreamBeginCaptureToGraph to capture
73+ // the stream directly in the native graph, rather than needing to
74+ // instantiate the stream capture as a new graph.
75+ auto graph = ih.ext_codeplay_get_native_graph <sycl::backend::ext_oneapi_cuda>();
76+ CUDA_ERROR_FUNC (cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr , nullptr , 0 ,
77+ CU_STREAM_CAPTURE_MODE_GLOBAL);
78+ #else
79+ CUDA_ERROR_FUNC (cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
80+ #endif // CUDA_VERSION
81+ #endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
6182}
6283
63- CUstream CublasScopedContextHandler::get_stream () {
64- return ih.get_native_queue <sycl::backend::ext_oneapi_cuda>();
84+ void CublasScopedContextHandler::end_recording_if_graph () {
85+ // interop_handle graph methods only available from extension version 2
86+ #if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
87+ if (!ih.ext_codeplay_has_graph ()) {
88+ return ;
89+ }
90+
91+ auto graph = ih.ext_codeplay_get_native_graph <sycl::backend::ext_oneapi_cuda>();
92+ CUresult err;
93+ #if CUDA_VERSION >= 12030
94+ CUDA_ERROR_FUNC (cuStreamEndCapture, err, streamId, &graph);
95+ #else
96+ // cuStreamEndCapture returns a new graph, if we overwrite
97+ // "graph" it won't be picked up by the SYCL runtime, as
98+ // "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
99+ CUgraph recorded_graph;
100+ CUDA_ERROR_FUNC (cuStreamEndCapture, err, streamId, &recorded_graph);
101+
102+ // Add graph to native graph as a child node
103+ // Need to return a node object for the node to be created,
104+ // can't be nullptr.
105+ CUgraphNode node;
106+ CUDA_ERROR_FUNC (cuGraphAddChildGraphNode, err, &node, graph, nullptr , 0 , recorded_graph);
107+ #endif // CUDA_VERSION
108+ #endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
65109}
66110} // namespace cublas
67111} // namespace blas
0 commit comments