Skip to content

Commit 3789aaf

Browse files
authored
[cusolver] Add support for computation_error exceptions (#178)
This patch applies PR#162 to cusolver functions. Moreover, it introduces a check that can throw a `oneapi::mkl::lapack::computation_error`.
1 parent 8012ec4 commit 3789aaf

File tree

3 files changed

+786
-631
lines changed

3 files changed

+786
-631
lines changed

src/lapack/backends/cusolver/cusolver_batch.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,9 @@ sycl::event potrf_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t
494494
}
495495

496496
template <typename Func, typename T>
497-
inline sycl::event potrf_batch(Func func, sycl::queue &queue, oneapi::mkl::uplo *uplo,
498-
std::int64_t *n, T **a, std::int64_t *lda, std::int64_t group_count,
499-
std::int64_t *group_sizes, T *scratchpad,
497+
inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &queue,
498+
oneapi::mkl::uplo *uplo, std::int64_t *n, T **a, std::int64_t *lda,
499+
std::int64_t group_count, std::int64_t *group_sizes, T *scratchpad,
500500
std::int64_t scratchpad_size,
501501
const std::vector<sycl::event> &dependencies) {
502502
using cuDataType = typename CudaEquivalentType<T>::Type;
@@ -523,8 +523,9 @@ inline sycl::event potrf_batch(Func func, sycl::queue &queue, oneapi::mkl::uplo
523523
cusolverStatus_t err;
524524
for (int64_t i = 0; i < group_count; i++) {
525525
auto **a_ = reinterpret_cast<cuDataType **>(a_dev);
526-
CUSOLVER_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(uplo[i]), (int)n[i],
527-
a_ + offset, (int)lda[i], nullptr, (int)group_sizes[i]);
526+
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo[i]),
527+
(int)n[i], a_ + offset, (int)lda[i], nullptr,
528+
(int)group_sizes[i]);
528529
offset += group_sizes[i];
529530
}
530531
});
@@ -538,8 +539,8 @@ inline sycl::event potrf_batch(Func func, sycl::queue &queue, oneapi::mkl::uplo
538539
sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, TYPE **a, std::int64_t *lda, \
539540
std::int64_t group_count, std::int64_t *group_sizes, TYPE *scratchpad, \
540541
std::int64_t scratchpad_size, const std::vector<sycl::event> &dependencies) { \
541-
return potrf_batch(CUSOLVER_ROUTINE, queue, uplo, n, a, lda, group_count, group_sizes, \
542-
scratchpad, scratchpad_size, dependencies); \
542+
return potrf_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, a, lda, \
543+
group_count, group_sizes, scratchpad, scratchpad_size, dependencies); \
543544
}
544545

545546
POTRF_BATCH_LAUNCHER_USM(float, cusolverDnSpotrfBatched)
@@ -581,10 +582,10 @@ sycl::event potrs_batch(sycl::queue &queue, oneapi::mkl::uplo uplo, std::int64_t
581582
}
582583

583584
template <typename Func, typename T>
584-
inline sycl::event potrs_batch(Func func, sycl::queue &queue, oneapi::mkl::uplo *uplo,
585-
std::int64_t *n, std::int64_t *nrhs, T **a, std::int64_t *lda, T **b,
586-
std::int64_t *ldb, std::int64_t group_count,
587-
std::int64_t *group_sizes, T *scratchpad,
585+
inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &queue,
586+
oneapi::mkl::uplo *uplo, std::int64_t *n, std::int64_t *nrhs, T **a,
587+
std::int64_t *lda, T **b, std::int64_t *ldb,
588+
std::int64_t group_count, std::int64_t *group_sizes, T *scratchpad,
588589
std::int64_t scratchpad_size,
589590
const std::vector<sycl::event> &dependencies) {
590591
using cuDataType = typename CudaEquivalentType<T>::Type;
@@ -624,9 +625,9 @@ inline sycl::event potrs_batch(Func func, sycl::queue &queue, oneapi::mkl::uplo
624625
auto **a_ = reinterpret_cast<cuDataType **>(a_dev);
625626
auto **b_ = reinterpret_cast<cuDataType **>(b_dev);
626627
auto info_ = reinterpret_cast<int *>(info);
627-
CUSOLVER_ERROR_FUNC(func, err, handle, get_cublas_fill_mode(uplo[i]), (int)n[i],
628-
(int)nrhs[i], a_ + offset, (int)lda[i], b_ + offset,
629-
(int)ldb[i], info_, (int)group_sizes[i]);
628+
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_fill_mode(uplo[i]),
629+
(int)n[i], (int)nrhs[i], a_ + offset, (int)lda[i],
630+
b_ + offset, (int)ldb[i], info_, (int)group_sizes[i]);
630631
offset += group_sizes[i];
631632
}
632633
});
@@ -635,14 +636,15 @@ inline sycl::event potrs_batch(Func func, sycl::queue &queue, oneapi::mkl::uplo
635636
}
636637

637638
// Scratchpad memory not needed as parts of buffer a are used as workspace memory
638-
#define POTRS_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \
639-
sycl::event potrs_batch( \
640-
sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, std::int64_t *nrhs, \
641-
TYPE **a, std::int64_t *lda, TYPE **b, std::int64_t *ldb, std::int64_t group_count, \
642-
std::int64_t *group_sizes, TYPE *scratchpad, std::int64_t scratchpad_size, \
643-
const std::vector<sycl::event> &dependencies) { \
644-
return potrs_batch(CUSOLVER_ROUTINE, queue, uplo, n, nrhs, a, lda, b, ldb, group_count, \
645-
group_sizes, scratchpad, scratchpad_size, dependencies); \
639+
#define POTRS_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \
640+
sycl::event potrs_batch( \
641+
sycl::queue &queue, oneapi::mkl::uplo *uplo, std::int64_t *n, std::int64_t *nrhs, \
642+
TYPE **a, std::int64_t *lda, TYPE **b, std::int64_t *ldb, std::int64_t group_count, \
643+
std::int64_t *group_sizes, TYPE *scratchpad, std::int64_t scratchpad_size, \
644+
const std::vector<sycl::event> &dependencies) { \
645+
return potrs_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, uplo, n, nrhs, a, lda, b, \
646+
ldb, group_count, group_sizes, scratchpad, scratchpad_size, \
647+
dependencies); \
646648
}
647649

648650
POTRS_BATCH_LAUNCHER_USM(float, cusolverDnSpotrsBatched)

src/lapack/backends/cusolver/cusolver_helper.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
#include "oneapi/mkl/types.hpp"
3333
#include "runtime_support_helper.hpp"
34+
#include "oneapi/mkl/exceptions.hpp"
35+
#include "oneapi/mkl/lapack/exceptions.hpp"
3436

3537
namespace oneapi {
3638
namespace mkl {
@@ -173,6 +175,12 @@ class cuda_error : virtual public std::runtime_error {
173175
throw cusolver_error(std::string(#name) + std::string(" : "), err); \
174176
}
175177

178+
/*
 * Invokes `func(__VA_ARGS__)`, stores its status in `err`, and throws a
 * cusolver_error prefixed with `name` when the status is not
 * CUSOLVER_STATUS_SUCCESS.
 *
 * Unlike a stringifying macro, `name` is taken as a runtime string
 * (anything std::string can be constructed from), so templated wrappers can
 * forward the real cuSOLVER routine name instead of the literal text "func".
 *
 * The body is wrapped in do { ... } while (0) so the expansion is a single
 * statement: it is fully guarded when used under an unbraced if/for and can
 * be followed by `else`. Call it with a trailing semicolon.
 */
#define CUSOLVER_ERROR_FUNC_T(name, func, err, ...)                            \
    do {                                                                       \
        err = func(__VA_ARGS__);                                               \
        if (err != CUSOLVER_STATUS_SUCCESS) {                                  \
            throw cusolver_error(std::string(name) + std::string(" : "), err); \
        }                                                                      \
    } while (0)
183+
176184
inline cusolverEigType_t get_cusolver_itype(std::int64_t itype) {
177185
switch (itype) {
178186
case 1: return CUSOLVER_EIG_TYPE_1;
@@ -251,6 +259,30 @@ struct CudaEquivalentType<std::complex<double>> {
251259
using Type = cuDoubleComplex;
252260
};
253261

262+
/* devinfo */
263+
264+
/*
 * Buffer overload: returns element 0 of `devInfo`, read on the host through
 * a read-only host accessor.
 *
 * `queue` is not used here; the parameter is kept so that this overload and
 * the USM-pointer overload can be called with the same argument list (as
 * lapack_info_check does).
 */
inline int get_cusolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo) {
    using read_only_host_accessor = sycl::host_accessor<int, 1, sycl::access::mode::read>;
    read_only_host_accessor info_view{ devInfo };
    const int info_value = info_view[0];
    return info_value;
}
268+
269+
/*
 * USM overload: copies the single int pointed to by `devInfo` into host
 * memory and returns it.
 *
 * queue   queue used to perform the copy.
 * devInfo pointer to the int to read; it is handed to queue.memcpy as the
 *         copy source, so it must be memory that queue.memcpy can read.
 */
inline int get_cusolver_devinfo(sycl::queue &queue, const int *devInfo) {
    int dev_info_ = 0;
    // Wait for everything already submitted to the queue before reading;
    // presumably this covers the routine that writes devInfo (confirm
    // against callers).
    queue.wait();
    // memcpy only enqueues the copy and returns an event. Wait on that event
    // so dev_info_ is actually written before it is returned.
    queue.memcpy(&dev_info_, devInfo, sizeof(int)).wait();
    return dev_info_;
}
275+
276+
template <typename DEVINFO_T>
277+
inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name,
278+
const char *cufunc_name) {
279+
const int devinfo_ = get_cusolver_devinfo(queue, devinfo);
280+
if (devinfo_ > 0)
281+
throw oneapi::mkl::lapack::computation_error(
282+
func_name, std::string(cufunc_name) + " failed with info = " + std::to_string(devinfo_),
283+
devinfo_);
284+
}
285+
254286
} // namespace cusolver
255287
} // namespace lapack
256288
} // namespace mkl

0 commit comments

Comments
 (0)