Skip to content

Commit 55352fe

Browse files
committed
[BLAS][generic] Fix compilation with AdaptiveCpp
The generic BLAS backend currently offers limited support for AdaptiveCpp, where: * complex data type is not supported * USM API is not supported Add the required protections to make the generic BLAS backend compile and run correctly in the capacity it offers with AdaptiveCpp. That is, make the buffer USM without complex data work fine. Throw the unimplemented exception for the unsupported features. Note: fixes on the backend side are also required to actually compile, see uxlfoundation/generic-sycl-components#7
1 parent 0260ff7 commit 55352fe

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/blas/backends/generic/generic_common.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ using handle_t = ::blas::SB_Handle;
4040
template <typename ElemT>
4141
using buffer_iterator_t = ::blas::BufferIterator<ElemT>;
4242

43+
#ifdef BLAS_ENABLE_COMPLEX
4344
// sycl complex data type (experimental)
4445
template <typename ElemT>
4546
using sycl_complex_t = sycl::ext::oneapi::experimental::complex<ElemT>;
47+
#endif
4648

4749
/** A trait for obtaining equivalent onemath_sycl_blas API types from oneMath API
4850
* types.
@@ -68,8 +70,10 @@ DEF_GENERIC_BLAS_TYPE(oneapi::math::transpose, char)
6870
DEF_GENERIC_BLAS_TYPE(oneapi::math::uplo, char)
6971
DEF_GENERIC_BLAS_TYPE(oneapi::math::side, char)
7072
DEF_GENERIC_BLAS_TYPE(oneapi::math::diag, char)
73+
#ifdef BLAS_ENABLE_COMPLEX
7174
DEF_GENERIC_BLAS_TYPE(std::complex<float>, sycl_complex_t<float>)
7275
DEF_GENERIC_BLAS_TYPE(std::complex<double>, sycl_complex_t<double>)
76+
#endif
7377
// Passthrough of onemath_sycl_blas arg types for more complex wrapping.
7478
DEF_GENERIC_BLAS_TYPE(::blas::gemm_batch_type_t, ::blas::gemm_batch_type_t)
7579

@@ -85,6 +89,7 @@ struct generic_type<ElemT*> {
8589
using type = ElemT*;
8690
};
8791

92+
#ifdef BLAS_ENABLE_COMPLEX
8893
// USM Complex
8994
template <typename ElemT>
9095
struct generic_type<std::complex<ElemT>*> {
@@ -95,6 +100,7 @@ template <typename ElemT>
95100
struct generic_type<const std::complex<ElemT>*> {
96101
using type = const sycl_complex_t<ElemT>*;
97102
};
103+
#endif
98104

99105
template <>
100106
struct generic_type<std::vector<sycl::event>> {
@@ -210,6 +216,10 @@ struct throw_if_unsupported_by_device {
210216
throw unimplemented("blas", "onemath_sycl_blas function"); \
211217
}
212218

219+
#ifndef SB_ENABLE_USM
220+
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
221+
throw unimplemented("blas", "onemath_sycl_blas USM API", "- unsupported compiler");
222+
#else
213223
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
214224
if constexpr (is_column_major()) { \
215225
detail::throw_if_unsupported_by_device<double, sycl::aspect::fp64>{}( \
@@ -230,6 +240,7 @@ struct throw_if_unsupported_by_device {
230240
else { \
231241
throw unimplemented("blas", "onemath_sycl_blas function"); \
232242
}
243+
#endif
233244

234245
} // namespace generic
235246
} // namespace blas

src/blas/backends/generic/generic_level3.cxx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ void gemm(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::tran
3232
sycl::buffer<std::complex<real_t>, 1>& a, std::int64_t lda,
3333
sycl::buffer<std::complex<real_t>, 1>& b, std::int64_t ldb, std::complex<real_t> beta,
3434
sycl::buffer<std::complex<real_t>, 1>& c, std::int64_t ldc) {
35+
#ifndef BLAS_ENABLE_COMPLEX
36+
throw unimplemented("blas", "onemath_sycl_blas gemm with complex data type", "- unsupported compiler");
37+
#else
3538
using sycl_complex_real_t = sycl::ext::oneapi::experimental::complex<real_t>;
3639
if (transa == oneapi::math::transpose::conjtrans ||
3740
transb == oneapi::math::transpose::conjtrans) {
@@ -62,6 +65,7 @@ void gemm(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::tran
6265
sycl::accessor<std::complex<real_t>, 1, sycl::access::mode::write> out_acc(c);
6366
sycl::accessor<sycl_complex_real_t, 1, sycl::access::mode::read> out_pb_acc(c_pb);
6467
queue.copy(out_pb_acc, out_acc);
68+
#endif
6569
}
6670

6771
void symm(sycl::queue& queue, oneapi::math::side left_right, oneapi::math::uplo upper_lower,

0 commit comments

Comments
 (0)