3737// MKLGPU header
3838#include " oneapi/mkl/dfti.hpp"
3939
40+ // MKL 2024.1 deprecates input/output strides.
41+ #include " mkl_version.h"
42+ #if INTEL_MKL_VERSION < 20240001
43+ #error MKLGPU requires oneMKL 2024.1 or later
44+ #endif
45+
4046/* *
4147Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface
4248of this OneMKL open-source library. Consequently, the types under dft::TYPE are closed-source oneMKL types,
@@ -53,14 +59,22 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
5359 // Equivalent MKLGPU precision and domain from OneMKL's precision / domain.
5460 static constexpr dft::precision mklgpu_prec = to_mklgpu(prec);
5561 static constexpr dft::domain mklgpu_dom = to_mklgpu(dom);
62+
63+ // A pair of descriptors is needed because of the [[deprecated]] IN/OUTPUT_STRIDES vs F/BWD_STRIDES API.
64+ // In the pair, first is the forward DFT descriptor and second is the backward DFT descriptor. If possible, both pointers refer to the same descriptor.
65+ // Both pointers must be valid.
5666 using mklgpu_descriptor_t = dft::descriptor<mklgpu_prec, mklgpu_dom>;
67+ using descriptor_shptr_t = std::shared_ptr<mklgpu_descriptor_t >;
68+ using handle_t = std::pair<descriptor_shptr_t , descriptor_shptr_t >;
69+
5770 using scalar_type = typename dft::detail::commit_impl<prec, dom>::scalar_type;
5871
5972public:
6073 mklgpu_commit (sycl::queue queue, const dft::detail::dft_values<prec, dom>& config_values)
6174 : oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::mklgpu,
6275 config_values),
63- handle (config_values.dimensions) {
76+ handle (std::make_shared<mklgpu_descriptor_t >(config_values.dimensions), nullptr ) {
77+ handle.second = handle.first ; // Make sure the bwd pointer is valid.
6478 // MKLGPU does not throw an informative exception for the following:
6579 if constexpr (prec == dft::detail::precision::DOUBLE) {
6680 if (!queue.get_device ().has (sycl::aspect::fp64)) {
@@ -75,13 +89,43 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
7589 oneapi::mkl::dft::detail::external_workspace_helper<prec, dom>(
7690 config_values.workspace_placement ==
7791 oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);
78- set_value (handle, config_values);
92+
93+ // A separate descriptor for each direction may not be required.
94+ bool one_descriptor = config_values.input_strides == config_values.output_strides ;
95+ bool forward_good = true ;
96+ // Make sure that second is always pointing to something new if this is a recommit.
97+ handle.second = handle.first ;
98+
99+ // Generate forward DFT descriptor.
100+ set_value (*handle.first , config_values, true );
79101 try {
80- handle.commit (this ->get_queue ());
102+ handle.first -> commit (this ->get_queue ());
81103 }
82104 catch (const std::exception& mkl_exception) {
83- // Catching the real Intel oneMKL exception causes headaches with naming.
84- throw mkl::exception (" dft/backends/mklgpu" , " commit" , mkl_exception.what ());
105+ // Catching the real Intel oneMKL exception causes headaches with naming
106+ forward_good = false ;
107+ if (one_descriptor) {
108+ throw mkl::exception (" dft/backends/mklgpu"
109+ " commit" ,
110+ mkl_exception.what ());
111+ }
112+ }
113+
114+ // Generate backward DFT descriptor only if required.
115+ if (!one_descriptor) {
116+ handle.second = std::make_shared<mklgpu_descriptor_t >(config_values.dimensions );
117+ set_value (*handle.second , config_values, false );
118+ try {
119+ handle.second ->commit (this ->get_queue ());
120+ }
121+ catch (const std::exception& mkl_exception) {
122+ // Catching the real Intel oneMKL exception causes headaches with naming.
123+ if (!forward_good) {
124+ throw mkl::exception (" dft/backends/mklgpu"
125+ " commit" ,
126+ mkl_exception.what ());
127+ }
128+ }
85129 }
86130 }
87131
@@ -93,12 +137,18 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
93137
94138 virtual void set_workspace (scalar_type* usm_workspace) override {
95139 this ->external_workspace_helper_ .set_workspace_throw (*this , usm_workspace);
96- handle.set_workspace (usm_workspace);
140+ handle.first ->set_workspace (usm_workspace);
141+ if (handle.first != handle.second ) {
142+ handle.second ->set_workspace (usm_workspace);
143+ }
97144 }
98145
99146 virtual void set_workspace (sycl::buffer<scalar_type>& buffer_workspace) override {
100147 this ->external_workspace_helper_ .set_workspace_throw (*this , buffer_workspace);
101- handle.set_workspace (buffer_workspace);
148+ handle.first ->set_workspace (buffer_workspace);
149+ if (handle.first != handle.second ) {
150+ handle.second ->set_workspace (buffer_workspace);
151+ }
102152 }
103153
104154#define BACKEND mklgpu
@@ -107,9 +157,10 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
107157
108158private:
109159 // The native MKLGPU class.
110- mklgpu_descriptor_t handle;
160+ handle_t handle;
111161
112- void set_value (mklgpu_descriptor_t & desc, const dft::detail::dft_values<prec, dom>& config) {
162+ void set_value (mklgpu_descriptor_t & desc, const dft::detail::dft_values<prec, dom>& config,
163+ bool assume_fwd_dft) {
113164 using onemkl_param = dft::detail::config_param;
114165 using backend_param = dft::config_param;
115166
@@ -134,8 +185,14 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
134185 throw mkl::unimplemented (" dft/backends/mklgpu" , " commit" ,
135186 " MKLGPU does not support nonzero offsets." );
136187 }
137- desc.set_value (backend_param::INPUT_STRIDES, config.input_strides .data ());
138- desc.set_value (backend_param::OUTPUT_STRIDES, config.output_strides .data ());
188+ if (assume_fwd_dft) {
189+ desc.set_value (backend_param::FWD_STRIDES, config.input_strides .data ());
190+ desc.set_value (backend_param::BWD_STRIDES, config.output_strides .data ());
191+ }
192+ else {
193+ desc.set_value (backend_param::FWD_STRIDES, config.output_strides .data ());
194+ desc.set_value (backend_param::BWD_STRIDES, config.input_strides .data ());
195+ }
139196 desc.set_value (backend_param::FWD_DISTANCE, config.fwd_dist );
140197 desc.set_value (backend_param::BWD_DISTANCE, config.bwd_dist );
141198 if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {
@@ -158,9 +215,10 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
158215
159216 // This is called by the workspace_helper, and is not part of the user API.
160217 virtual std::int64_t get_workspace_external_bytes_impl () override {
161- std::size_t workspaceSize = 0 ;
162- handle.get_value (dft::config_param::WORKSPACE_BYTES, &workspaceSize);
163- return static_cast <std::int64_t >(workspaceSize);
218+ std::size_t workspaceSizeFwd = 0 , workspaceSizeBwd = 0 ;
219+ handle.first ->get_value (dft::config_param::WORKSPACE_BYTES, &workspaceSizeFwd);
220+ handle.second ->get_value (dft::config_param::WORKSPACE_BYTES, &workspaceSizeBwd);
221+ return static_cast <std::int64_t >(std::max (workspaceSizeFwd, workspaceSizeFwd));
164222 }
165223};
166224} // namespace detail
0 commit comments