Skip to content

Commit 6e3c97c

Browse files
author
Ewan Crawford
committed
[BLAS] SYCL-Graph integration for native-command
In order to support applications calling the library with a SYCL queue that is recording to a SYCL-Graph, check if the `ext_codeplay_enqueue_native_command` command-group is being recorded to a graph object. If so, use the native stream recording APIs to add the BLAS calls as nodes in the graph.

In particular, this fixes the llama.cpp unit test `MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0)` on CUDA with SYCL-Graph enabled. Previously this would throw an error:

```sh
$ GGML_SYCL_DISABLE_GRAPH=0 ./bin/test-backend-ops -b SYCL0 -o MUL_MAT -p type_a=f16,type_b=f32,m=16,n=1,k=256,bs=\\[1,1\\],nr=\\[2
UR CUDA ERROR:
Value: 700
Name: CUDA_ERROR_ILLEGAL_ADDRESS
Description: an illegal memory access was encountered
Function: operator()
Source Location: $HOME/dpcpp/unified-runtime/source/adapters/cuda/queue.cpp:154
Native API failed. Native API returns: 2147483646 (UR_RESULT_ERROR_UNKNOWN)
Exception caught at file:$HOME/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp, line:3598, func:operator()
SYCL error: CHECK_TRY_ERROR((stream)->wait()): Meet error in this line code!
in function ggml_backend_sycl_synchronize at $HOME/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp:3598
$HOME/llama.cpp/ggml/src/ggml-sycl/../ggml-sycl/common.hpp:118: SYCL error
Could not attach to process. If your uid matches the uid of the target process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try again as the root user. For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
```
1 parent 4a51281 commit 6e3c97c

File tree

6 files changed

+233
-87
lines changed

6 files changed

+233
-87
lines changed

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,80 @@ namespace cublas {
3232
*/
3333
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
3434

35-
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}
35+
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {
36+
// Initialize the `streamId` member to the CUstream associated with the queue `ih`
37+
// has been submitted to.
38+
streamId = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
3639

37-
cublasHandle_t CublasScopedContextHandler::get_handle() {
40+
// Initialize the `cublasHandle_t` member `nativeHandle`
3841
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
39-
CUstream streamId = get_stream();
40-
cublasStatus_t err;
41-
4242
auto it = handle_helper.cublas_handle_mapper_.find(device);
4343
if (it != handle_helper.cublas_handle_mapper_.end()) {
44-
cublasHandle_t nativeHandle = it->second;
44+
// Use existing handle if one already exists for the device, but update
45+
// the native stream.
46+
nativeHandle = it->second;
4547
cudaStream_t currentStreamId;
48+
cublasStatus_t err;
4649
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
4750
if (currentStreamId != streamId) {
4851
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
4952
}
50-
return nativeHandle;
5153
}
52-
53-
cublasHandle_t nativeHandle;
54-
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
55-
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
56-
57-
auto insert_iter =
54+
else {
55+
// Create a new handle if one doesn't already exist for the device
56+
cublasStatus_t err;
57+
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
58+
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
5859
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));
60+
}
61+
}
5962

60-
return nativeHandle;
63+
void CublasScopedContextHandler::begin_recording_if_graph() {
64+
// interop_handle graph methods only available from extension version 2
65+
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
66+
if (!ih.ext_codeplay_has_graph()) {
67+
return;
68+
}
69+
70+
CUresult err;
71+
#if CUDA_VERSION >= 12030
72+
// From CUDA 12.3 onwards we can use cuStreamBeginCaptureToGraph to capture
73+
// the stream directly in the native graph, rather than needing to
74+
// instantiate the stream capture as a new graph.
75+
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
76+
CUDA_ERROR_FUNC(cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr, nullptr, 0,
77+
CU_STREAM_CAPTURE_MODE_GLOBAL);
78+
#else
79+
CUDA_ERROR_FUNC(cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
80+
#endif // CUDA_VERSION
81+
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
6182
}
6283

63-
CUstream CublasScopedContextHandler::get_stream() {
64-
return ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
84+
void CublasScopedContextHandler::end_recording_if_graph() {
85+
// interop_handle graph methods only available from extension version 2
86+
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
87+
if (!ih.ext_codeplay_has_graph()) {
88+
return;
89+
}
90+
91+
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
92+
CUresult err;
93+
#if CUDA_VERSION >= 12030
94+
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &graph);
95+
#else
96+
// cuStreamEndCapture returns a new graph, if we overwrite
97+
// "graph" it won't be picked up by the SYCL runtime, as
98+
// "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
99+
CUgraph recorded_graph;
100+
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &recorded_graph);
101+
102+
// Add graph to native graph as a child node
103+
// Need to return a node object for the node to be created,
104+
// can't be nullptr.
105+
CUgraphNode node;
106+
CUDA_ERROR_FUNC(cuGraphAddChildGraphNode, err, &node, graph, nullptr, 0, recorded_graph);
107+
#endif // CUDA_VERSION
108+
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
65109
}
66110
} // namespace cublas
67111
} // namespace blas

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,49 @@ the handle must be destroyed when the context goes out of scope. This will bind
6363
class CublasScopedContextHandler {
6464
sycl::interop_handle& ih;
6565
static thread_local cublas_handle handle_helper;
66-
CUstream get_stream();
66+
cublasHandle_t nativeHandle;
67+
// Cache the native CU stream when the `CublasScopedContextHandler` object
68+
// is constructed. This avoids calling `get_native_queue(ih)` multiple
69+
// times which isn't guaranteed to return the same CUstream handle each
70+
// time. A scenario that causes problems when trying to start/end cuda
71+
// stream recording to a graph.
72+
CUstream streamId;
6773

6874
public:
75+
/**
76+
* @brief Constructor
77+
* @detail Creates the cublasHandle_t by implicitly imposing the advice
78+
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
79+
* per thread).
80+
*/
6981
CublasScopedContextHandler(sycl::interop_handle& ih);
7082

7183
/**
72-
* @brief get_handle: creates the handle by implicitly impose the advice
73-
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
74-
* per thread).
75-
* @return cublasHandle_t a handle to construct cublas routines
76-
*/
77-
cublasHandle_t get_handle();
84+
* @brief Start recording cuBlas calls to a graph.
85+
* @detail Checks if the command-group associated with \p ih is being added
86+
* to a graph, and if so, begin stream recording of the native CUDA stream
87+
* associated with \p ih to the native cuda-graph object.
88+
*/
89+
void begin_recording_if_graph();
90+
91+
/**
92+
* @brief End recording cuBlas calls to a graph.
93+
* @detail Checks if the command-group associated with \p ih is being added
94+
* to a graph, and if so, ends stream recording of the native CUDA stream
95+
* associated with \p ih to the native cuda-graph object. Also does any
96+
* extra work to ensure that stream recorded calls get added as nodes to
97+
* the native graph object associated with \p ih.
98+
* @param queue The sycl queue to end stream recording on native stream
99+
* backing the queue.
100+
*/
101+
void end_recording_if_graph();
102+
103+
/// @brief Query the cuBLAS handle created on construction
104+
/// @return cublasHandle_t a handle to construct cublas routines
105+
cublasHandle_t get_handle() const {
106+
return nativeHandle;
107+
}
108+
78109
// This is a work-around function for reinterpret_casting the memory. This
79110
// will be fixed when SYCL-2020 has been implemented for Pi backend.
80111
template <typename T, typename U>

src/blas/backends/cublas/cublas_task.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ static inline void host_task_internal(H& cgh, F f) {
6161
cgh.host_task([f](sycl::interop_handle ih) {
6262
#endif
6363
auto sc = CublasScopedContextHandler(ih);
64+
sc.begin_recording_if_graph();
6465
f(sc);
66+
sc.end_recording_if_graph();
6567
});
6668
}
6769
#endif

tests/unit_tests/blas/batch/gemm_batch_usm.cpp

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ extern std::vector<sycl::device*> devices;
4848
namespace {
4949

5050
template <typename Ta, typename Tb, typename Tc, typename Ts>
51-
int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
51+
int test(device* dev, oneapi::math::layout layout, int64_t group_count, bool graph_record = false) {
5252
// Catch asynchronous exceptions.
5353
auto exception_handler = [](exception_list exceptions) {
5454
for (std::exception_ptr const& e : exceptions) {
@@ -247,6 +247,13 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
247247

248248
try {
249249
#ifdef CALL_RT_API
250+
namespace sycl_exp = sycl::ext::oneapi::experimental;
251+
using modifiable_graph = sycl_exp::command_graph<sycl_exp::graph_state::modifiable>;
252+
std::unique_ptr<modifiable_graph> graph;
253+
if (graph_record) {
254+
graph = std::make_unique<modifiable_graph>(main_queue);
255+
graph->begin_recording(main_queue);
256+
}
250257
switch (layout) {
251258
case oneapi::math::layout::col_major:
252259
done = oneapi::math::blas::column_major::gemm_batch(
@@ -262,7 +269,15 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
262269
break;
263270
default: break;
264271
}
265-
done.wait_and_throw();
272+
273+
if (graph_record) {
274+
graph->end_recording(main_queue);
275+
auto exec_graph = graph->finalize();
276+
main_queue.ext_oneapi_graph(exec_graph).wait_and_throw();
277+
}
278+
else {
279+
done.wait_and_throw();
280+
}
266281
#else
267282
switch (layout) {
268283
case oneapi::math::layout::col_major:
@@ -365,58 +380,65 @@ int test(device* dev, oneapi::math::layout layout, int64_t group_count) {
365380
}
366381

367382
class GemmBatchUsmTests
368-
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout>> {};
383+
: public ::testing::TestWithParam<std::tuple<sycl::device*, oneapi::math::layout, bool>> {
384+
virtual void SetUp() override {
385+
// Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph
386+
if (std::get<2>(GetParam())) {
387+
CHECK_GRAPH_ON_DEVICE(std::get<0>(GetParam()));
388+
}
389+
}
390+
};
369391

370392
TEST_P(GemmBatchUsmTests, RealHalfPrecision) {
371393
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, sycl::half, sycl::half>(
372-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
394+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
373395
}
374396

375397
TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) {
376-
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(std::get<0>(GetParam()),
377-
std::get<1>(GetParam()), 5)));
398+
EXPECT_TRUEORSKIP((test<sycl::half, sycl::half, float, float>(
399+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
378400
}
379401

380402
TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) {
381-
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(std::get<0>(GetParam()),
382-
std::get<1>(GetParam()), 5)));
403+
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, float, float>(
404+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
383405
}
384406

385407
TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) {
386408
EXPECT_TRUEORSKIP((test<std::int8_t, std::int8_t, std::int32_t, float>(
387-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
409+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
388410
}
389411

390412
TEST_P(GemmBatchUsmTests, RealSinglePrecision) {
391-
EXPECT_TRUEORSKIP(
392-
(test<float, float, float, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
413+
EXPECT_TRUEORSKIP((test<float, float, float, float>(
414+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
393415
}
394416

395417
TEST_P(GemmBatchUsmTests, RealDoublePrecision) {
396418
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
397419

398-
EXPECT_TRUEORSKIP((
399-
test<double, double, double, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
420+
EXPECT_TRUEORSKIP((test<double, double, double, double>(
421+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
400422
}
401423

402424
TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) {
403425
EXPECT_TRUEORSKIP(
404426
(test<std::complex<float>, std::complex<float>, std::complex<float>, std::complex<float>>(
405-
std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
427+
std::get<0>(GetParam()), std::get<1>(GetParam()), 5, std::get<2>(GetParam()))));
406428
}
407429

408430
TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) {
409431
CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam()));
410432

411-
EXPECT_TRUEORSKIP(
412-
(test<std::complex<double>, std::complex<double>, std::complex<double>,
413-
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)));
433+
EXPECT_TRUEORSKIP((test<std::complex<double>, std::complex<double>, std::complex<double>,
434+
std::complex<double>>(std::get<0>(GetParam()), std::get<1>(GetParam()),
435+
5, std::get<2>(GetParam()))));
414436
}
415437

416438
INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests,
417439
::testing::Combine(testing::ValuesIn(devices),
418440
testing::Values(oneapi::math::layout::col_major,
419-
oneapi::math::layout::row_major)),
420-
::LayoutDeviceNamePrint());
421-
441+
oneapi::math::layout::row_major),
442+
testing::Values(true, false)),
443+
::LayoutGraphDeviceNamePrint());
422444
} // anonymous namespace

0 commit comments

Comments
 (0)