Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dft/backends/mklgpu/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ inline auto compute_backward(dft::detail::descriptor<prec, dom>& desc, ArgTs&&..
}
auto handle = reinterpret_cast<handle_t*>(commit_handle->get_handle());
auto mklgpu_desc = handle->second; // Second because backward DFT.
int commit_status{ DFTI_UNCOMMITTED };
auto commit_status = uncommitted;
mklgpu_desc->get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
if (commit_status != DFTI_COMMITTED) {
if (commit_status != committed) {
throw math::invalid_argument("DFT", "compute_backward",
"MKLGPU DFT descriptor was not successfully committed.");
}
Expand Down
50 changes: 33 additions & 17 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,20 @@
#include <mkl_version.h>
#if INTEL_MKL_VERSION < 20250000
#include <mkl/dfti.hpp>
namespace oneapi::math::dft::mklgpu::detail {
template <typename D, typename C, typename V>
void set_vector_value(D& desc, C p, V const& vec) {
desc.set_value(p, vec.data());
}
} // namespace oneapi::math::dft::mklgpu::detail
#else
#include <mkl/dft.hpp>
namespace oneapi::math::dft::mklgpu::detail {
template <typename D, typename C, typename V>
void set_vector_value(D& desc, C p, V const& vec) {
desc.set_value(p, vec);
}
} // namespace oneapi::math::dft::mklgpu::detail
#endif

// Intel oneMKL 2024.1 deprecates input/output strides.
Expand Down Expand Up @@ -159,14 +171,19 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
desc.set_value(backend_param::FORWARD_SCALE, config.fwd_scale);
desc.set_value(backend_param::BACKWARD_SCALE, config.bwd_scale);
desc.set_value(backend_param::NUMBER_OF_TRANSFORMS, config.number_of_transforms);
desc.set_value(backend_param::COMPLEX_STORAGE,
to_mklgpu<onemath_param::COMPLEX_STORAGE>(config.complex_storage));
if (config.complex_storage != dft::detail::config_value::COMPLEX_COMPLEX) {
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU only supports complex-complex complex storage.");
}
if (config.real_storage != dft::detail::config_value::REAL_REAL) {
throw math::invalid_argument("dft/backends/mklgpu", "commit",
"MKLGPU only supports real-real real storage.");
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU only supports real-real real storage.");
}
if (config.conj_even_storage != dft::detail::config_value::COMPLEX_COMPLEX) {
throw math::unimplemented(
"dft/backends/mklgpu", "commit",
"MKLGPU only supports complex-complex conjugate even storage.");
}
desc.set_value(backend_param::CONJUGATE_EVEN_STORAGE,
to_mklgpu<onemath_param::CONJUGATE_EVEN_STORAGE>(config.conj_even_storage));
desc.set_value(backend_param::PLACEMENT,
to_mklgpu<onemath_param::PLACEMENT>(config.placement));

Expand All @@ -175,30 +192,29 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
desc.set_value(backend_param::FWD_STRIDES, config.fwd_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.bwd_strides.data());
set_vector_value(desc, backend_param::FWD_STRIDES, config.fwd_strides);
set_vector_value(desc, backend_param::BWD_STRIDES, config.bwd_strides);
}
else {
if (config.input_strides[0] != 0 || config.output_strides[0] != 0) {
throw math::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
set_vector_value(desc, backend_param::FWD_STRIDES, config.input_strides);
set_vector_value(desc, backend_param::BWD_STRIDES, config.output_strides);
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
set_vector_value(desc, backend_param::FWD_STRIDES, config.output_strides);
set_vector_value(desc, backend_param::BWD_STRIDES, config.input_strides);
}
}
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {
// Setting WORKSPACE_INTERNAL (default) causes FFT_INVALID_DESCRIPTOR.
desc.set_value(backend_param::WORKSPACE,
to_mklgpu_config_value<onemath_param::WORKSPACE_PLACEMENT>(
config.workspace_placement));
desc.set_value(backend_param::WORKSPACE, to_mklgpu<onemath_param::WORKSPACE_PLACEMENT>(
config.workspace_placement));
}
// Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used:
if (config.ordering != dft::detail::config_value::ORDERED) {
Expand All @@ -214,11 +230,11 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {

// This is called by the workspace_helper, and is not part of the user API.
virtual std::int64_t get_workspace_external_bytes_impl() override {
std::size_t workspaceSizeFwd = 0, workspaceSizeBwd = 0;
std::int64_t workspaceSizeFwd = 0, workspaceSizeBwd = 0;
using backend_param = oneapi::mkl::dft::config_param;
handle.first->get_value(backend_param::WORKSPACE_BYTES, &workspaceSizeFwd);
handle.second->get_value(backend_param::WORKSPACE_BYTES, &workspaceSizeBwd);
return static_cast<std::int64_t>(std::max(workspaceSizeFwd, workspaceSizeFwd));
return std::max(workspaceSizeFwd, workspaceSizeFwd);
}
};
} // namespace detail
Expand Down
4 changes: 2 additions & 2 deletions src/dft/backends/mklgpu/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ inline auto compute_forward(dft::detail::descriptor<prec, dom>& desc, ArgTs&&...
}
auto handle = reinterpret_cast<handle_t*>(commit_handle->get_handle());
auto mklgpu_desc = handle->first; // First because forward DFT.
int commit_status{ DFTI_UNCOMMITTED };
auto commit_status = uncommitted;
mklgpu_desc->get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status);
if (commit_status != DFTI_COMMITTED) {
if (commit_status != committed) {
throw math::invalid_argument("DFT", "compute_forward",
"MKLGPU DFT descriptor was not successfully committed.");
}
Expand Down
153 changes: 51 additions & 102 deletions src/dft/backends/mklgpu/mklgpu_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@
#include <mkl_version.h>
#if INTEL_MKL_VERSION < 20250000
#include <mkl/dfti.hpp>
namespace oneapi::math::dft::mklgpu::detail {
constexpr int committed = DFTI_COMMITTED;
constexpr int uncommitted = DFTI_UNCOMMITTED;
} // namespace oneapi::math::dft::mklgpu::detail
#else
#include <mkl/dft.hpp>
namespace oneapi::math::dft::mklgpu::detail {
constexpr auto committed = oneapi::mkl::dft::config_value::COMMITTED;
constexpr auto uncommitted = oneapi::mkl::dft::config_value::UNCOMMITTED;
} // namespace oneapi::math::dft::mklgpu::detail
#endif

namespace oneapi {
Expand Down Expand Up @@ -57,120 +65,61 @@ inline constexpr oneapi::mkl::dft::precision to_mklgpu(dft::detail::precision do
}
}

/// Convert a config_param to equivalent backend native value.
inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_param param) {
using iparam = dft::detail::config_param;
using oparam = oneapi::mkl::dft::config_param;
switch (param) {
case iparam::FORWARD_DOMAIN: return oparam::FORWARD_DOMAIN;
case iparam::DIMENSION: return oparam::DIMENSION;
case iparam::LENGTHS: return oparam::LENGTHS;
case iparam::PRECISION: return oparam::PRECISION;
case iparam::FORWARD_SCALE: return oparam::FORWARD_SCALE;
case iparam::NUMBER_OF_TRANSFORMS: return oparam::NUMBER_OF_TRANSFORMS;
case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE;
case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE;
case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE;
case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE;
case iparam::WORKSPACE: return oparam::WORKSPACE;
case iparam::PACKED_FORMAT: return oparam::PACKED_FORMAT;
case iparam::WORKSPACE_PLACEMENT: return oparam::WORKSPACE; // Same as WORKSPACE
case iparam::WORKSPACE_EXTERNAL_BYTES: return oparam::WORKSPACE_BYTES;
case iparam::COMMIT_STATUS: return oparam::COMMIT_STATUS;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config param.");
return static_cast<oparam>(0);
}
}
template <dft::detail::config_param Param>
struct to_mklgpu_impl;

/** Convert a config_value to the backend's native value. Throw on invalid input.
* @tparam Param The config param the value is for.
* @param value The config value to convert.
**/
template <dft::detail::config_param Param>
inline constexpr int to_mklgpu(dft::detail::config_value value);

template <>
inline constexpr int to_mklgpu<dft::detail::config_param::COMPLEX_STORAGE>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::COMPLEX_COMPLEX) {
return DFTI_COMPLEX_COMPLEX;
}
else {
throw math::unimplemented("dft", "MKLGPU descriptor set_value()",
"MKLGPU only supports complex-complex for complex storage.");
return 0;
}
}

template <>
inline constexpr int to_mklgpu<dft::detail::config_param::CONJUGATE_EVEN_STORAGE>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::COMPLEX_COMPLEX) {
return DFTI_COMPLEX_COMPLEX;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for conjugate even storage.");
return 0;
}
inline constexpr auto to_mklgpu(dft::detail::config_value value) {
return to_mklgpu_impl<Param>{}(value);
}

#if INTEL_MKL_VERSION < 20250000
template <>
inline constexpr int to_mklgpu<dft::detail::config_param::PLACEMENT>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::INPLACE) {
return DFTI_INPLACE;
}
else if (value == dft::detail::config_value::NOT_INPLACE) {
return DFTI_NOT_INPLACE;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
return 0;
}
}

struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) -> int {
switch (value) {
case dft::detail::config_value::INPLACE: return DFTI_INPLACE;
case dft::detail::config_value::NOT_INPLACE: return DFTI_NOT_INPLACE;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
#else
template <>
inline constexpr int to_mklgpu<dft::detail::config_param::PACKED_FORMAT>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::CCE_FORMAT) {
return DFTI_CCE_FORMAT;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for packed format.");
return 0;
}
}

/** Convert a config_value to the backend's native value. Throw on invalid input.
* @tparam Param The config param the value is for.
* @param value The config value to convert.
**/
template <dft::detail::config_param Param>
inline constexpr oneapi::mkl::dft::config_value to_mklgpu_config_value(
dft::detail::config_value value);
struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) {
switch (value) {
case dft::detail::config_value::INPLACE: return oneapi::mkl::dft::config_value::INPLACE;
case dft::detail::config_value::NOT_INPLACE:
return oneapi::mkl::dft::config_value::NOT_INPLACE;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
#endif

template <>
inline constexpr oneapi::mkl::dft::config_value
to_mklgpu_config_value<dft::detail::config_param::WORKSPACE_PLACEMENT>(
dft::detail::config_value value) {
if (value == dft::detail::config_value::WORKSPACE_AUTOMATIC) {
// NB: oneapi::mkl::dft::config_value != dft::detail::config_value
return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
}
else if (value == dft::detail::config_value::WORKSPACE_EXTERNAL) {
return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL;
}
else {
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for workspace placement.");
return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
}
}
struct to_mklgpu_impl<dft::detail::config_param::WORKSPACE_PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) {
switch (value) {
case dft::detail::config_value::WORKSPACE_AUTOMATIC:
return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
case dft::detail::config_value::WORKSPACE_EXTERNAL:
return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
} // namespace detail
} // namespace mklgpu
} // namespace dft
Expand Down
Loading