3333#include " oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp"
3434
3535#include " dft/backends/mklgpu/mklgpu_helpers.hpp"
36+ #include " ../stride_helper.hpp"
3637
3738// MKLGPU header
3839#include " oneapi/mkl/dfti.hpp"
@@ -90,14 +91,18 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
9091 config_values.workspace_placement ==
9192 oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);
9293
94+ auto stride_choice = dft::detail::get_stride_api (config_values);
95+ throw_on_invalid_stride_api (" MKLGPU commit" , stride_choice);
9396 // A separate descriptor for each direction may not be required.
94- bool one_descriptor = config_values.input_strides == config_values.output_strides ;
97+ bool one_descriptor = (stride_choice == dft::detail::stride_api::FB_STRIDES) ||
98+ (config_values.input_strides == config_values.output_strides );
9599 bool forward_good = true ;
96- // Make sure that second is always pointing to something new if this is a recommit.
100+ // Make sure that second is always pointing to something new if this is a recommit.
97101 handle.second = handle.first ;
98102
99- // Generate forward DFT descriptor.
100- set_value (*handle.first , config_values, true );
103+ // Generate forward DFT descriptor. If using FWD/BWD_STRIDES API, only
104+ // one descriptor is needed.
105+ set_value (*handle.first , config_values, true , stride_choice);
101106 try {
102107 handle.first ->commit (this ->get_queue ());
103108 }
@@ -114,7 +119,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
114119 // Generate backward DFT descriptor only if required.
115120 if (!one_descriptor) {
116121 handle.second = std::make_shared<mklgpu_descriptor_t >(config_values.dimensions );
117- set_value (*handle.second , config_values, false );
122+ set_value (*handle.second , config_values, false , stride_choice );
118123 try {
119124 handle.second ->commit (this ->get_queue ());
120125 }
@@ -160,7 +165,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
160165 handle_t handle;
161166
162167 void set_value (mklgpu_descriptor_t & desc, const dft::detail::dft_values<prec, dom>& config,
163- bool assume_fwd_dft) {
168+ bool assume_fwd_dft, dft::detail::stride_api stride_choice ) {
164169 using onemkl_param = dft::detail::config_param;
165170 using backend_param = dft::config_param;
166171
@@ -181,17 +186,27 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
181186 desc.set_value (backend_param::PLACEMENT,
182187 to_mklgpu<onemkl_param::PLACEMENT>(config.placement ));
183188
184- if (config. input_strides [ 0 ] != 0 || config. output_strides [ 0 ] != 0 ) {
185- throw mkl::unimplemented ( " dft/backends/mklgpu " , " commit " ,
186- " MKLGPU does not support nonzero offsets. " );
187- }
188- if (assume_fwd_dft) {
189- desc.set_value (backend_param::FWD_STRIDES, config.input_strides .data ());
190- desc.set_value (backend_param::BWD_STRIDES, config.output_strides .data ());
189+ if (stride_choice == dft::detail::stride_api::FB_STRIDES ) {
190+ if (config. fwd_strides [ 0 ] != 0 || config. fwd_strides [ 0 ] != 0 ) {
191+ throw mkl::unimplemented ( " dft/backends/mklgpu " , " commit " ,
192+ " MKLGPU does not support nonzero offsets. " );
193+ }
194+ desc.set_value (backend_param::FWD_STRIDES, config.fwd_strides .data ());
195+ desc.set_value (backend_param::BWD_STRIDES, config.bwd_strides .data ());
191196 }
192197 else {
193- desc.set_value (backend_param::FWD_STRIDES, config.output_strides .data ());
194- desc.set_value (backend_param::BWD_STRIDES, config.input_strides .data ());
198+ if (config.input_strides [0 ] != 0 || config.output_strides [0 ] != 0 ) {
199+ throw mkl::unimplemented (" dft/backends/mklgpu" , " commit" ,
200+ " MKLGPU does not support nonzero offsets." );
201+ }
202+ if (assume_fwd_dft) {
203+ desc.set_value (backend_param::FWD_STRIDES, config.input_strides .data ());
204+ desc.set_value (backend_param::BWD_STRIDES, config.output_strides .data ());
205+ }
206+ else {
207+ desc.set_value (backend_param::FWD_STRIDES, config.output_strides .data ());
208+ desc.set_value (backend_param::BWD_STRIDES, config.input_strides .data ());
209+ }
195210 }
196211 desc.set_value (backend_param::FWD_DISTANCE, config.fwd_dist );
197212 desc.set_value (backend_param::BWD_DISTANCE, config.bwd_dist );
0 commit comments