Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/blas/backends/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ endif()
# If find_package doesn't work, download onemath_sycl_blas from Github. This is
# intended to make oneMath easier to use.
message(STATUS "Looking for oneMATH blas kernels")
find_package(ONEMATH_SYCL_BLAS QUIET)
find_package(ONEMATH_SYCL_BLAS 0.2.0 QUIET)
if (NOT ONEMATH_SYCL_BLAS_FOUND)
message(STATUS "Looking for onemath_sycl_blas for generic backend - could not find onemath_sycl_blas with ONEMATH_SYCL_BLAS_DIR")
include(FetchContent)
Expand All @@ -150,7 +150,6 @@ if (NOT ONEMATH_SYCL_BLAS_FOUND)
endif()
# Following variable TUNING_TARGET will be used in generic blas internal configuration
set(TUNING_TARGET ${GENERIC_BLAS_TUNING_TARGET})
set(BLAS_ENABLE_COMPLEX ON)
# Set the policy to forward variables to generic blas configure step
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps")
Expand All @@ -161,13 +160,21 @@ if (NOT ONEMATH_SYCL_BLAS_FOUND)
SOURCE_SUBDIR onemath/sycl/blas
)
FetchContent_MakeAvailable(onemath_sycl_blas)
install(
TARGETS onemath_sycl_blas
EXPORT oneMathTargets
)
message(STATUS "Looking for onemath_sycl_blas - downloaded")

else()
message(STATUS "Looking for oneMath blas kernels - found")
add_library(onemath_sycl_blas ALIAS ONEMATH_SYCL_BLAS::onemath_sycl_blas)
endif()

# Read cmake options exported by the onemath_sycl_blas project into oneMath variables
set(ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX ${BLAS_ENABLE_COMPLEX} CACHE INTERNAL "Enable complex data support")
set(ONEMATH_GENERIC_BLAS_ENABLE_USM ${BLAS_ENABLE_USM} CACHE INTERNAL "Enable USM API support")

set(SOURCES
generic_level1_double.cpp generic_level1_float.cpp
generic_level2_double.cpp generic_level2_float.cpp
Expand Down
12 changes: 12 additions & 0 deletions src/blas/backends/generic/generic_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define _GENERIC_BLAS_COMMON_HPP_

#include "onemath_sycl_blas.hpp"
#include "oneapi/math/detail/config.hpp"
#include "oneapi/math/types.hpp"
#include "oneapi/math/exceptions.hpp"

Expand All @@ -40,9 +41,11 @@ using handle_t = ::blas::SB_Handle;
template <typename ElemT>
using buffer_iterator_t = ::blas::BufferIterator<ElemT>;

#ifdef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
// sycl complex data type (experimental)
template <typename ElemT>
using sycl_complex_t = sycl::ext::oneapi::experimental::complex<ElemT>;
#endif

/** A trait for obtaining equivalent onemath_sycl_blas API types from oneMath API
* types.
Expand All @@ -68,8 +71,10 @@ DEF_GENERIC_BLAS_TYPE(oneapi::math::transpose, char)
DEF_GENERIC_BLAS_TYPE(oneapi::math::uplo, char)
DEF_GENERIC_BLAS_TYPE(oneapi::math::side, char)
DEF_GENERIC_BLAS_TYPE(oneapi::math::diag, char)
#ifdef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
DEF_GENERIC_BLAS_TYPE(std::complex<float>, sycl_complex_t<float>)
DEF_GENERIC_BLAS_TYPE(std::complex<double>, sycl_complex_t<double>)
#endif
// Passthrough of onemath_sycl_blas arg types for more complex wrapping.
DEF_GENERIC_BLAS_TYPE(::blas::gemm_batch_type_t, ::blas::gemm_batch_type_t)

Expand All @@ -85,6 +90,7 @@ struct generic_type<ElemT*> {
using type = ElemT*;
};

#ifdef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
// USM Complex
template <typename ElemT>
struct generic_type<std::complex<ElemT>*> {
Expand All @@ -95,6 +101,7 @@ template <typename ElemT>
struct generic_type<const std::complex<ElemT>*> {
using type = const sycl_complex_t<ElemT>*;
};
#endif

template <>
struct generic_type<std::vector<sycl::event>> {
Expand Down Expand Up @@ -210,6 +217,10 @@ struct throw_if_unsupported_by_device {
throw unimplemented("blas", "onemath_sycl_blas function"); \
}

#ifndef ONEMATH_GENERIC_BLAS_ENABLE_USM
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
throw unimplemented("blas", "onemath_sycl_blas USM API", "- unsupported compiler");
#else
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
if constexpr (is_column_major()) { \
detail::throw_if_unsupported_by_device<double, sycl::aspect::fp64>{}( \
Expand All @@ -230,6 +241,7 @@ struct throw_if_unsupported_by_device {
else { \
throw unimplemented("blas", "onemath_sycl_blas function"); \
}
#endif

} // namespace generic
} // namespace blas
Expand Down
5 changes: 5 additions & 0 deletions src/blas/backends/generic/generic_level3.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ void gemm(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::tran
sycl::buffer<std::complex<real_t>, 1>& a, std::int64_t lda,
sycl::buffer<std::complex<real_t>, 1>& b, std::int64_t ldb, std::complex<real_t> beta,
sycl::buffer<std::complex<real_t>, 1>& c, std::int64_t ldc) {
#ifndef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
throw unimplemented("blas", "onemath_sycl_blas gemm with complex data type",
"- unsupported compiler");
#else
using sycl_complex_real_t = sycl::ext::oneapi::experimental::complex<real_t>;
if (transa == oneapi::math::transpose::conjtrans ||
transb == oneapi::math::transpose::conjtrans) {
Expand Down Expand Up @@ -62,6 +66,7 @@ void gemm(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::tran
sycl::accessor<std::complex<real_t>, 1, sycl::access::mode::write> out_acc(c);
sycl::accessor<sycl_complex_real_t, 1, sycl::access::mode::read> out_pb_acc(c_pb);
queue.copy(out_pb_acc, out_acc);
#endif
}

void symm(sycl::queue& queue, oneapi::math::side left_right, oneapi::math::uplo upper_lower,
Expand Down
2 changes: 2 additions & 0 deletions src/config.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
#cmakedefine ONEMATH_ENABLE_GENERIC_BLAS_BACKEND_INTEL_CPU
#cmakedefine ONEMATH_ENABLE_GENERIC_BLAS_BACKEND_INTEL_GPU
#cmakedefine ONEMATH_ENABLE_GENERIC_BLAS_BACKEND_NVIDIA_GPU
#cmakedefine ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
#cmakedefine ONEMATH_GENERIC_BLAS_ENABLE_USM
#cmakedefine ONEMATH_ENABLE_PORTFFT_BACKEND
#cmakedefine ONEMATH_ENABLE_ROCBLAS_BACKEND
#cmakedefine ONEMATH_ENABLE_ROCFFT_BACKEND
Expand Down
Loading