Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dft/backends/mklgpu/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ inline auto compute_backward(dft::detail::descriptor<prec, dom>& desc, ArgTs&&..
}
auto handle = reinterpret_cast<handle_t*>(commit_handle->get_handle());
auto mklgpu_desc = handle->second; // Second because backward DFT.
int commit_status{ DFTI_UNCOMMITTED };
auto commit_status = uncommitted;
mklgpu_desc->get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
if (commit_status != DFTI_COMMITTED) {
if (commit_status != committed) {
throw math::invalid_argument("DFT", "compute_backward",
"MKLGPU DFT descriptor was not successfully committed.");
}
Expand Down
50 changes: 33 additions & 17 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,20 @@
#include <mkl_version.h>
#if INTEL_MKL_VERSION < 20250000
#include <mkl/dfti.hpp>
namespace oneapi::math::dft::mklgpu::detail {
/// Pass a vector-valued configuration entry on to a backend descriptor.
///
/// This variant is the one compiled when INTEL_MKL_VERSION < 20250000: it hands
/// the descriptor the container's underlying buffer, i.e. `vec.data()`.
///
/// @param desc Descriptor object exposing `set_value(param, pointer)`.
/// @param p    Configuration parameter to set.
/// @param vec  Container whose `data()` pointer is forwarded.
template <class Descriptor, class Param, class Vector>
void set_vector_value(Descriptor& desc, Param p, const Vector& vec) {
    desc.set_value(p, vec.data());
}
} // namespace oneapi::math::dft::mklgpu::detail
#else
#include <mkl/dft.hpp>
namespace oneapi::math::dft::mklgpu::detail {
/// Pass a vector-valued configuration entry on to a backend descriptor.
///
/// This variant is the one compiled when INTEL_MKL_VERSION >= 20250000: the
/// container itself (not its raw buffer) is handed to `set_value`.
///
/// @param desc Descriptor object exposing `set_value(param, container)`.
/// @param p    Configuration parameter to set.
/// @param vec  Container forwarded by const reference.
template <class Descriptor, class Param, class Vector>
void set_vector_value(Descriptor& desc, Param p, const Vector& vec) {
    desc.set_value(p, vec);
}
} // namespace oneapi::math::dft::mklgpu::detail
#endif

// Intel oneMKL 2024.1 deprecates input/output strides.
Expand Down Expand Up @@ -159,14 +171,19 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
desc.set_value(backend_param::FORWARD_SCALE, config.fwd_scale);
desc.set_value(backend_param::BACKWARD_SCALE, config.bwd_scale);
desc.set_value(backend_param::NUMBER_OF_TRANSFORMS, config.number_of_transforms);
desc.set_value(backend_param::COMPLEX_STORAGE,
to_mklgpu<onemath_param::COMPLEX_STORAGE>(config.complex_storage));
if (config.complex_storage != dft::detail::config_value::COMPLEX_COMPLEX) {
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU only supports complex-complex complex storage.");
}
if (config.real_storage != dft::detail::config_value::REAL_REAL) {
throw math::invalid_argument("dft/backends/mklgpu", "commit",
"MKLGPU only supports real-real real storage.");
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU only supports real-real real storage.");
}
if (config.conj_even_storage != dft::detail::config_value::COMPLEX_COMPLEX) {
throw math::unimplemented(
"dft/backends/mklgpu", "commit",
"MKLGPU only supports complex-complex conjugate even storage.");
}
desc.set_value(backend_param::CONJUGATE_EVEN_STORAGE,
to_mklgpu<onemath_param::CONJUGATE_EVEN_STORAGE>(config.conj_even_storage));
desc.set_value(backend_param::PLACEMENT,
to_mklgpu<onemath_param::PLACEMENT>(config.placement));

Expand All @@ -175,30 +192,29 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
desc.set_value(backend_param::FWD_STRIDES, config.fwd_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.bwd_strides.data());
set_vector_value(desc, backend_param::FWD_STRIDES, config.fwd_strides);
set_vector_value(desc, backend_param::BWD_STRIDES, config.bwd_strides);
}
else {
if (config.input_strides[0] != 0 || config.output_strides[0] != 0) {
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
set_vector_value(desc, backend_param::FWD_STRIDES, config.input_strides);
set_vector_value(desc, backend_param::BWD_STRIDES, config.output_strides);
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
set_vector_value(desc, backend_param::FWD_STRIDES, config.output_strides);
set_vector_value(desc, backend_param::BWD_STRIDES, config.input_strides);
}
}
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {
// Setting WORKSPACE_INTERNAL (default) causes FFT_INVALID_DESCRIPTOR.
desc.set_value(backend_param::WORKSPACE,
to_mklgpu_config_value<onemath_param::WORKSPACE_PLACEMENT>(
config.workspace_placement));
desc.set_value(backend_param::WORKSPACE, to_mklgpu<onemath_param::WORKSPACE_PLACEMENT>(
config.workspace_placement));
}
// Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used:
if (config.ordering != dft::detail::config_value::ORDERED) {
Expand All @@ -214,11 +230,11 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {

// This is called by the workspace_helper, and is not part of the user API.
virtual std::int64_t get_workspace_external_bytes_impl() override {
    // Ask both backend descriptors (handle.first for the forward DFT,
    // handle.second for the backward DFT) how many bytes of external
    // workspace they need, and report the larger requirement so that one
    // buffer is big enough for either direction.
    using backend_param = oneapi::mkl::dft::config_param;
    // NOTE(review): std::int64_t is the newer of the two declarations present
    // in this span; confirm it matches the pointer type get_value expects for
    // WORKSPACE_BYTES on every supported MKL version.
    std::int64_t workspaceSizeFwd = 0;
    std::int64_t workspaceSizeBwd = 0;
    handle.first->get_value(backend_param::WORKSPACE_BYTES, &workspaceSizeFwd);
    handle.second->get_value(backend_param::WORKSPACE_BYTES, &workspaceSizeBwd);
    // The previous code compared workspaceSizeFwd with itself, so the backward
    // descriptor's requirement was never taken into account.
    return std::max(workspaceSizeFwd, workspaceSizeBwd);
}
};
} // namespace detail
Expand Down
4 changes: 2 additions & 2 deletions src/dft/backends/mklgpu/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ inline auto compute_forward(dft::detail::descriptor<prec, dom>& desc, ArgTs&&...
}
auto handle = reinterpret_cast<handle_t*>(commit_handle->get_handle());
auto mklgpu_desc = handle->first; // First because forward DFT.
int commit_status{ DFTI_UNCOMMITTED };
auto commit_status = uncommitted;
mklgpu_desc->get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
if (commit_status != DFTI_COMMITTED) {
if (commit_status != committed) {
throw math::invalid_argument("DFT", "compute_forward",
"MKLGPU DFT descriptor was not successfully committed.");
}
Expand Down
132 changes: 54 additions & 78 deletions src/dft/backends/mklgpu/mklgpu_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@
#include <mkl_version.h>
#if INTEL_MKL_VERSION < 20250000
#include <mkl/dfti.hpp>
namespace oneapi::math::dft::mklgpu::detail {
// Commit-status sentinels for INTEL_MKL_VERSION < 20250000, where the status is
// expressed with the DFTI_* integer constants. They serve as the initial value
// and the expected value when COMMIT_STATUS is queried from a descriptor.
// `inline` gives every translation unit that includes this header the same
// single definition instead of a private internal-linkage copy.
inline constexpr int committed = DFTI_COMMITTED;
inline constexpr int uncommitted = DFTI_UNCOMMITTED;
} // namespace oneapi::math::dft::mklgpu::detail
#else
#include <mkl/dft.hpp>
namespace oneapi::math::dft::mklgpu::detail {
// Commit-status sentinels for INTEL_MKL_VERSION >= 20250000, where the status is
// expressed with oneapi::mkl::dft::config_value enumerators. They serve as the
// initial value and the expected value when COMMIT_STATUS is queried from a
// descriptor.
// `inline` gives every translation unit that includes this header the same
// single definition instead of a private internal-linkage copy.
inline constexpr auto committed = oneapi::mkl::dft::config_value::COMMITTED;
inline constexpr auto uncommitted = oneapi::mkl::dft::config_value::UNCOMMITTED;
} // namespace oneapi::math::dft::mklgpu::detail
#endif

namespace oneapi {
Expand Down Expand Up @@ -58,7 +66,7 @@ inline constexpr oneapi::mkl::dft::precision to_mklgpu(dft::detail::precision do
}

/// Convert a config_param to equivalent backend native value.
inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_param param) {
/*inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_param param) {
using iparam = dft::detail::config_param;
using oparam = oneapi::mkl::dft::config_param;
switch (param) {
Expand All @@ -82,95 +90,63 @@ inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_pa
"Invalid config param.");
return static_cast<oparam>(0);
}
}
}*/

/// Primary template for config_value conversion; declared but never defined.
/// Each supported config_param supplies an explicit specialization whose call
/// operator converts a dft::detail::config_value to the backend's native value.
template <dft::detail::config_param Param>
struct to_mklgpu_impl;

/** Convert a config_value to the backend's native value. Throw on invalid input.
* @tparam Param The config param the value is for.
* @param value The config value to convert.
**/
template <dft::detail::config_param Param>
inline constexpr int to_mklgpu(dft::detail::config_value value);

template <>
inline constexpr int to_mklgpu<dft::detail::config_param::COMPLEX_STORAGE>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::COMPLEX_COMPLEX) {
return DFTI_COMPLEX_COMPLEX;
}
else {
throw math::unimplemented("dft", "MKLGPU descriptor set_value()",
"MKLGPU only supports complex-complex for complex storage.");
return 0;
}
}

template <>
inline constexpr int to_mklgpu<dft::detail::config_param::CONJUGATE_EVEN_STORAGE>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::COMPLEX_COMPLEX) {
return DFTI_COMPLEX_COMPLEX;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for conjugate even storage.");
return 0;
}
inline constexpr auto to_mklgpu(dft::detail::config_value value) {
return to_mklgpu_impl<Param>{}(value);
}

#if INTEL_MKL_VERSION < 20250000
template <>
inline constexpr int to_mklgpu<dft::detail::config_param::PLACEMENT>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::INPLACE) {
return DFTI_INPLACE;
}
else if (value == dft::detail::config_value::NOT_INPLACE) {
return DFTI_NOT_INPLACE;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
return 0;
}
}

struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) -> int {
switch (value) {
case dft::detail::config_value::INPLACE: return DFTI_INPLACE;
case dft::detail::config_value::NOT_INPLACE: return DFTI_NOT_INPLACE;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
#else
template <>
inline constexpr int to_mklgpu<dft::detail::config_param::PACKED_FORMAT>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::CCE_FORMAT) {
return DFTI_CCE_FORMAT;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for packed format.");
return 0;
}
}

/** Convert a config_value to the backend's native value. Throw on invalid input.
* @tparam Param The config param the value is for.
* @param value The config value to convert.
**/
template <dft::detail::config_param Param>
inline constexpr oneapi::mkl::dft::config_value to_mklgpu_config_value(
dft::detail::config_value value);
struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) {
switch (value) {
case dft::detail::config_value::INPLACE: return oneapi::mkl::dft::config_value::INPLACE;
case dft::detail::config_value::NOT_INPLACE:
return oneapi::mkl::dft::config_value::NOT_INPLACE;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
#endif

template <>
inline constexpr oneapi::mkl::dft::config_value
to_mklgpu_config_value<dft::detail::config_param::WORKSPACE_PLACEMENT>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::WORKSPACE_AUTOMATIC) {
// NB: oneapi::mkl::dft::config_value != dft::detail::config_value
return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
}
else if (value == dft::detail::config_value::WORKSPACE_EXTERNAL) {
return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for workspace placement.");
return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
}
}
struct to_mklgpu_impl<dft::detail::config_param::WORKSPACE_PLACEMENT> {
    /// Translate a workspace-placement setting into the backend's config_value.
    /// @param value WORKSPACE_AUTOMATIC or WORKSPACE_EXTERNAL.
    /// @return WORKSPACE_INTERNAL for WORKSPACE_AUTOMATIC, WORKSPACE_EXTERNAL
    ///         for WORKSPACE_EXTERNAL.
    /// @throws math::invalid_argument for any other config value.
    inline constexpr auto operator()(dft::detail::config_value value) {
        switch (value) {
            // NB: the input and output enumerations are different types; the
            // "automatic" setting corresponds to the backend's WORKSPACE_INTERNAL.
            case dft::detail::config_value::WORKSPACE_AUTOMATIC:
                return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
            case dft::detail::config_value::WORKSPACE_EXTERNAL:
                return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL;
            default:
                // The message previously said "inplace", copied from the
                // placement converter; it now names the parameter at fault.
                throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
                                             "Invalid config value for workspace placement.");
        }
    }
};
} // namespace detail
} // namespace mklgpu
} // namespace dft
Expand Down
Loading