diff --git a/src/dft/backends/mklgpu/backward.cpp b/src/dft/backends/mklgpu/backward.cpp index 4af5c7005..84ad3ae66 100644 --- a/src/dft/backends/mklgpu/backward.cpp +++ b/src/dft/backends/mklgpu/backward.cpp @@ -56,9 +56,9 @@ inline auto compute_backward(dft::detail::descriptor& desc, ArgTs&&.. } auto handle = reinterpret_cast(commit_handle->get_handle()); auto mklgpu_desc = handle->second; // Second because backward DFT. - int commit_status{ DFTI_UNCOMMITTED }; + auto commit_status = uncommitted; mklgpu_desc->get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status); - if (commit_status != DFTI_COMMITTED) { + if (commit_status != committed) { throw math::invalid_argument("DFT", "compute_backward", "MKLGPU DFT descriptor was not successfully committed."); } diff --git a/src/dft/backends/mklgpu/commit.cpp b/src/dft/backends/mklgpu/commit.cpp index c92f9667c..c27eeb1b6 100644 --- a/src/dft/backends/mklgpu/commit.cpp +++ b/src/dft/backends/mklgpu/commit.cpp @@ -40,8 +40,20 @@ #include #if INTEL_MKL_VERSION < 20250000 #include +namespace oneapi::math::dft::mklgpu::detail { +template +void set_vector_value(D& desc, C p, V const& vec) { + desc.set_value(p, vec.data()); +} +} // namespace oneapi::math::dft::mklgpu::detail #else #include +namespace oneapi::math::dft::mklgpu::detail { +template +void set_vector_value(D& desc, C p, V const& vec) { + desc.set_value(p, vec); +} +} // namespace oneapi::math::dft::mklgpu::detail #endif // Intel oneMKL 2024.1 deprecates input/output strides. @@ -159,14 +171,19 @@ class mklgpu_commit final : public dft::detail::commit_impl { desc.set_value(backend_param::FORWARD_SCALE, config.fwd_scale); desc.set_value(backend_param::BACKWARD_SCALE, config.bwd_scale); desc.set_value(backend_param::NUMBER_OF_TRANSFORMS, config.number_of_transforms); - desc.set_value(backend_param::COMPLEX_STORAGE, - to_mklgpu(config.complex_storage)); + if (config.complex_storage != dft::detail::config_value::COMPLEX_COMPLEX) { + throw math::unimplemented("dft/backends/mklgpu", "commit", + "MKLGPU only supports complex-complex complex storage."); + } if (config.real_storage != dft::detail::config_value::REAL_REAL) { - throw math::invalid_argument("dft/backends/mklgpu", "commit", - "MKLGPU only supports real-real real storage."); + throw math::unimplemented("dft/backends/mklgpu", "commit", + "MKLGPU only supports real-real real storage."); + } + if (config.conj_even_storage != dft::detail::config_value::COMPLEX_COMPLEX) { + throw math::unimplemented( + "dft/backends/mklgpu", "commit", + "MKLGPU only supports complex-complex conjugate even storage."); } - desc.set_value(backend_param::CONJUGATE_EVEN_STORAGE, - to_mklgpu(config.conj_even_storage)); desc.set_value(backend_param::PLACEMENT, to_mklgpu(config.placement)); @@ -175,8 +192,8 @@ class mklgpu_commit final : public dft::detail::commit_impl { throw math::unimplemented("dft/backends/mklgpu", "commit", "MKLGPU does not support nonzero offsets."); } - desc.set_value(backend_param::FWD_STRIDES, config.fwd_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.bwd_strides.data()); + set_vector_value(desc, backend_param::FWD_STRIDES, config.fwd_strides); + set_vector_value(desc, backend_param::BWD_STRIDES, config.bwd_strides); } else { if (config.input_strides[0] != 0 || config.output_strides[0] != 0) { @@ -184,21 +201,20 @@ class mklgpu_commit final : public dft::detail::commit_impl { "MKLGPU does not support nonzero offsets."); } if (assume_fwd_dft) { - desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data()); + set_vector_value(desc, backend_param::FWD_STRIDES, config.input_strides); + set_vector_value(desc, backend_param::BWD_STRIDES, config.output_strides); } else { - desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data()); - desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data()); + set_vector_value(desc, backend_param::FWD_STRIDES, config.output_strides); + set_vector_value(desc, backend_param::BWD_STRIDES, config.input_strides); } } desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist); desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist); if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) { // Setting WORKSPACE_INTERNAL (default) causes FFT_INVALID_DESCRIPTOR. - desc.set_value(backend_param::WORKSPACE, - to_mklgpu_config_value( - config.workspace_placement)); + desc.set_value(backend_param::WORKSPACE, to_mklgpu( + config.workspace_placement)); } // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: if (config.ordering != dft::detail::config_value::ORDERED) { @@ -214,11 +230,11 @@ class mklgpu_commit final : public dft::detail::commit_impl { // This is called by the workspace_helper, and is not part of the user API. virtual std::int64_t get_workspace_external_bytes_impl() override { - std::size_t workspaceSizeFwd = 0, workspaceSizeBwd = 0; + std::int64_t workspaceSizeFwd = 0, workspaceSizeBwd = 0; using backend_param = oneapi::mkl::dft::config_param; handle.first->get_value(backend_param::WORKSPACE_BYTES, &workspaceSizeFwd); handle.second->get_value(backend_param::WORKSPACE_BYTES, &workspaceSizeBwd); - return static_cast(std::max(workspaceSizeFwd, workspaceSizeFwd)); + return std::max(workspaceSizeFwd, workspaceSizeFwd); } }; } // namespace detail diff --git a/src/dft/backends/mklgpu/forward.cpp b/src/dft/backends/mklgpu/forward.cpp index f4ce97b7b..4d5b9e054 100644 --- a/src/dft/backends/mklgpu/forward.cpp +++ b/src/dft/backends/mklgpu/forward.cpp @@ -62,9 +62,9 @@ inline auto compute_forward(dft::detail::descriptor& desc, ArgTs&&... } auto handle = reinterpret_cast(commit_handle->get_handle()); auto mklgpu_desc = handle->first; // First because forward DFT. - int commit_status{ DFTI_UNCOMMITTED }; + auto commit_status = uncommitted; mklgpu_desc->get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status); - if (commit_status != DFTI_COMMITTED) { + if (commit_status != committed) { throw math::invalid_argument("DFT", "compute_forward", "MKLGPU DFT descriptor was not successfully committed."); } diff --git a/src/dft/backends/mklgpu/mklgpu_helpers.hpp b/src/dft/backends/mklgpu/mklgpu_helpers.hpp index 3be413637..3433b0fb3 100644 --- a/src/dft/backends/mklgpu/mklgpu_helpers.hpp +++ b/src/dft/backends/mklgpu/mklgpu_helpers.hpp @@ -27,8 +27,16 @@ #include #if INTEL_MKL_VERSION < 20250000 #include +namespace oneapi::math::dft::mklgpu::detail { +constexpr int committed = DFTI_COMMITTED; +constexpr int uncommitted = DFTI_UNCOMMITTED; +} // namespace oneapi::math::dft::mklgpu::detail #else #include +namespace oneapi::math::dft::mklgpu::detail { +constexpr auto committed = oneapi::mkl::dft::config_value::COMMITTED; +constexpr auto uncommitted = oneapi::mkl::dft::config_value::UNCOMMITTED; +} // namespace oneapi::math::dft::mklgpu::detail #endif namespace oneapi { @@ -57,120 +65,61 @@ inline constexpr oneapi::mkl::dft::precision to_mklgpu(dft::detail::precision do } } -/// Convert a config_param to equivalent backend native value. -inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_param param) { - using iparam = dft::detail::config_param; - using oparam = oneapi::mkl::dft::config_param; - switch (param) { - case iparam::FORWARD_DOMAIN: return oparam::FORWARD_DOMAIN; - case iparam::DIMENSION: return oparam::DIMENSION; - case iparam::LENGTHS: return oparam::LENGTHS; - case iparam::PRECISION: return oparam::PRECISION; - case iparam::FORWARD_SCALE: return oparam::FORWARD_SCALE; - case iparam::NUMBER_OF_TRANSFORMS: return oparam::NUMBER_OF_TRANSFORMS; - case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE; - case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE; - case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE; - case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE; - case iparam::WORKSPACE: return oparam::WORKSPACE; - case iparam::PACKED_FORMAT: return oparam::PACKED_FORMAT; - case iparam::WORKSPACE_PLACEMENT: return oparam::WORKSPACE; // Same as WORKSPACE - case iparam::WORKSPACE_EXTERNAL_BYTES: return oparam::WORKSPACE_BYTES; - case iparam::COMMIT_STATUS: return oparam::COMMIT_STATUS; - default: - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", - "Invalid config param."); - return static_cast(0); - } -} +template +struct to_mklgpu_impl; /** Convert a config_value to the backend's native value. Throw on invalid input. * @tparam Param The config param the value is for. * @param value The config value to convert. **/ template -inline constexpr int to_mklgpu(dft::detail::config_value value); - -template <> -inline constexpr int to_mklgpu( - dft::detail::config_value value) { - if (value == dft::detail::config_value::COMPLEX_COMPLEX) { - return DFTI_COMPLEX_COMPLEX; - } - else { - throw math::unimplemented("dft", "MKLGPU descriptor set_value()", - "MKLGPU only supports complex-complex for complex storage."); - return 0; - } -} - -template <> -inline constexpr int to_mklgpu( - dft::detail::config_value value) { - if (value == dft::detail::config_value::COMPLEX_COMPLEX) { - return DFTI_COMPLEX_COMPLEX; - } - else { - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", - "Invalid config value for conjugate even storage."); - return 0; - } +inline constexpr auto to_mklgpu(dft::detail::config_value value) { + return to_mklgpu_impl{}(value); } +#if INTEL_MKL_VERSION < 20250000 template <> -inline constexpr int to_mklgpu( - dft::detail::config_value value) { - if (value == dft::detail::config_value::INPLACE) { - return DFTI_INPLACE; - } - else if (value == dft::detail::config_value::NOT_INPLACE) { - return DFTI_NOT_INPLACE; - } - else { - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", - "Invalid config value for inplace."); - return 0; - } -} - +struct to_mklgpu_impl { + inline constexpr auto operator()(dft::detail::config_value value) -> int { + switch (value) { + case dft::detail::config_value::INPLACE: return DFTI_INPLACE; + case dft::detail::config_value::NOT_INPLACE: return DFTI_NOT_INPLACE; + default: + throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", + "Invalid config value for inplace."); + } + } +}; +#else template <> -inline constexpr int to_mklgpu( - dft::detail::config_value value) { - if (value == dft::detail::config_value::CCE_FORMAT) { - return DFTI_CCE_FORMAT; - } - else { - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", - "Invalid config value for packed format."); - return 0; - } -} - -/** Convert a config_value to the backend's native value. Throw on invalid input. - * @tparam Param The config param the value is for. - * @param value The config value to convert. -**/ -template -inline constexpr oneapi::mkl::dft::config_value to_mklgpu_config_value( - dft::detail::config_value value); +struct to_mklgpu_impl { + inline constexpr auto operator()(dft::detail::config_value value) { + switch (value) { + case dft::detail::config_value::INPLACE: return oneapi::mkl::dft::config_value::INPLACE; + case dft::detail::config_value::NOT_INPLACE: + return oneapi::mkl::dft::config_value::NOT_INPLACE; + default: + throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", + "Invalid config value for inplace."); + } + } +}; +#endif template <> -inline constexpr oneapi::mkl::dft::config_value -to_mklgpu_config_value( - dft::detail::config_value value) { - if (value == dft::detail::config_value::WORKSPACE_AUTOMATIC) { - // NB: oneapi::mkl::dft::config_value != dft::detail::config_value - return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL; - } - else if (value == dft::detail::config_value::WORKSPACE_EXTERNAL) { - return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL; - } - else { - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", - "Invalid config value for workspace placement."); - return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL; - } -} +struct to_mklgpu_impl { + inline constexpr auto operator()(dft::detail::config_value value) { + switch (value) { + case dft::detail::config_value::WORKSPACE_AUTOMATIC: + return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL; + case dft::detail::config_value::WORKSPACE_EXTERNAL: + return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL; + default: + throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", + "Invalid config value for inplace."); + } + } +}; } // namespace detail } // namespace mklgpu } // namespace dft