Skip to content

Commit 78fac71

Browse files
hjabird authored and normallytangent committed
[DFT] Add FWD/BWD_STRIDES to public API, deprecate INPUT/OUTPUT_STRIDES (#528)
1 parent 032ae69 commit 78fac71

File tree

18 files changed: +576 additions, -312 deletions

include/oneapi/mkl/dft/detail/types_impl.hpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -148,8 +148,8 @@ enum class config_param {
148148

149149
PLACEMENT,
150150

151-
INPUT_STRIDES,
152-
OUTPUT_STRIDES,
151+
INPUT_STRIDES [[deprecated("Use FWD/BWD_STRIDES")]],
152+
OUTPUT_STRIDES [[deprecated("Use FWD/BWD_STRIDES")]],
153153

154154
FWD_DISTANCE,
155155
BWD_DISTANCE,
@@ -160,7 +160,10 @@ enum class config_param {
160160
ORDERING,
161161
TRANSPOSE,
162162
PACKED_FORMAT,
163-
COMMIT_STATUS
163+
COMMIT_STATUS,
164+
165+
FWD_STRIDES,
166+
BWD_STRIDES
164167
};
165168

166169
enum class config_value {
@@ -204,6 +207,8 @@ class dft_values {
204207
public:
205208
std::vector<std::int64_t> input_strides;
206209
std::vector<std::int64_t> output_strides;
210+
std::vector<std::int64_t> fwd_strides;
211+
std::vector<std::int64_t> bwd_strides;
207212
real_t bwd_scale;
208213
real_t fwd_scale;
209214
std::int64_t number_of_transforms;

src/dft/backends/cufft/backward.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ namespace oneapi::mkl::dft::cufft {
3737
namespace detail {
3838
//forward declaration
3939
template <dft::precision prec, dft::domain dom>
40-
std::array<std::int64_t, 2> get_offsets(dft::detail::commit_impl<prec, dom> *commit);
40+
std::array<std::int64_t, 2> get_offsets_bwd(dft::detail::commit_impl<prec, dom> *commit);
4141

4242
template <dft::precision prec, dft::domain dom>
4343
cufftHandle get_bwd_plan(dft::detail::commit_impl<prec, dom> *commit) {
@@ -56,7 +56,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
5656
auto commit = detail::checked_get_commit(desc);
5757
auto queue = commit->get_queue();
5858
auto plan = detail::get_bwd_plan(commit);
59-
auto offsets = detail::get_offsets(commit);
59+
auto offsets = detail::get_offsets_bwd(commit);
6060

6161
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
6262
offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer
@@ -102,7 +102,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
102102
auto commit = detail::checked_get_commit(desc);
103103
auto queue = commit->get_queue();
104104
auto plan = detail::get_bwd_plan(commit);
105-
auto offsets = detail::get_offsets(commit);
105+
auto offsets = detail::get_offsets_bwd(commit);
106106

107107
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
108108
if (offsets[1] % 2 != 0) {
@@ -156,7 +156,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
156156
auto commit = detail::checked_get_commit(desc);
157157
auto queue = commit->get_queue();
158158
auto plan = detail::get_bwd_plan(commit);
159-
auto offsets = detail::get_offsets(commit);
159+
auto offsets = detail::get_offsets_bwd(commit);
160160

161161
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
162162
offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer
@@ -203,7 +203,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
203203
auto commit = detail::checked_get_commit(desc);
204204
auto queue = commit->get_queue();
205205
auto plan = detail::get_bwd_plan(commit);
206-
auto offsets = detail::get_offsets(commit);
206+
auto offsets = detail::get_offsets_bwd(commit);
207207

208208
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
209209
if (offsets[1] % 2 != 0) {

src/dft/backends/cufft/commit.cpp

Lines changed: 104 additions & 99 deletions
Large diffs are not rendered by default.

src/dft/backends/cufft/forward.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ namespace oneapi::mkl::dft::cufft {
3939
namespace detail {
4040
//forward declaration
4141
template <dft::precision prec, dft::domain dom>
42-
std::array<std::int64_t, 2> get_offsets(dft::detail::commit_impl<prec, dom> *commit);
42+
std::array<std::int64_t, 2> get_offsets_fwd(dft::detail::commit_impl<prec, dom> *commit);
4343

4444
template <dft::precision prec, dft::domain dom>
4545
cufftHandle get_fwd_plan(dft::detail::commit_impl<prec, dom> *commit) {
@@ -59,7 +59,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
5959
auto commit = detail::checked_get_commit(desc);
6060
auto queue = commit->get_queue();
6161
auto plan = detail::get_fwd_plan(commit);
62-
auto offsets = detail::get_offsets(commit);
62+
auto offsets = detail::get_offsets_fwd(commit);
6363

6464
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
6565
if (offsets[0] % 2 != 0) {
@@ -104,7 +104,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
104104
auto commit = detail::checked_get_commit(desc);
105105
auto queue = commit->get_queue();
106106
auto plan = detail::get_fwd_plan(commit);
107-
auto offsets = detail::get_offsets(commit);
107+
auto offsets = detail::get_offsets_fwd(commit);
108108

109109
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
110110
if (offsets[0] % 2 != 0) {
@@ -158,7 +158,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
158158
auto commit = detail::checked_get_commit(desc);
159159
auto queue = commit->get_queue();
160160
auto plan = detail::get_fwd_plan(commit);
161-
auto offsets = detail::get_offsets(commit);
161+
auto offsets = detail::get_offsets_fwd(commit);
162162

163163
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
164164
if (offsets[0] % 2 != 0) {
@@ -205,7 +205,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
205205
auto commit = detail::checked_get_commit(desc);
206206
auto queue = commit->get_queue();
207207
auto plan = detail::get_fwd_plan(commit);
208-
auto offsets = detail::get_offsets(commit);
208+
auto offsets = detail::get_offsets_fwd(commit);
209209

210210
if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
211211
if (offsets[0] % 2 != 0) {

src/dft/backends/mklcpu/commit.cpp

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
#include "oneapi/mkl/dft/detail/commit_impl.hpp"
3434

3535
#include "dft/backends/mklcpu/commit_derived_impl.hpp"
36+
#include "../stride_helper.hpp"
3637
#include "mkl_service.h"
3738
#include "mkl_dfti.h"
3839

@@ -129,9 +130,23 @@ void commit_derived_impl<prec, dom>::set_value_item(mklcpu_desc_t hand, enum DFT
129130
template <dft::detail::precision prec, dft::detail::domain dom>
130131
void commit_derived_impl<prec, dom>::set_value(mklcpu_desc_t* descHandle,
131132
const dft::detail::dft_values<prec, dom>& config) {
133+
auto stride_choice = dft::detail::get_stride_api(config);
134+
dft::detail::throw_on_invalid_stride_api("MKLCPU commit", stride_choice);
132135
for (auto dir : { DIR::fwd, DIR::bwd }) {
133-
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data());
134-
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data());
136+
if (stride_choice == dft::detail::stride_api::IO_STRIDES) {
137+
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data());
138+
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data());
139+
}
140+
else { // Forward / backward strides
141+
if (dir == DIR::fwd) {
142+
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.fwd_strides.data());
143+
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.bwd_strides.data());
144+
}
145+
else {
146+
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.bwd_strides.data());
147+
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.fwd_strides.data());
148+
}
149+
}
135150
set_value_item(descHandle[dir], DFTI_BACKWARD_SCALE, config.bwd_scale);
136151
set_value_item(descHandle[dir], DFTI_FORWARD_SCALE, config.fwd_scale);
137152
set_value_item(descHandle[dir], DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms);

src/dft/backends/mklcpu/mklcpu_helpers.hpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,8 +74,6 @@ inline constexpr DFTI_CONFIG_PARAM to_mklcpu(dft::detail::config_param param) {
7474
case iparam::COMPLEX_STORAGE: return DFTI_COMPLEX_STORAGE;
7575
case iparam::REAL_STORAGE: return DFTI_REAL_STORAGE;
7676
case iparam::CONJUGATE_EVEN_STORAGE: return DFTI_CONJUGATE_EVEN_STORAGE;
77-
case iparam::INPUT_STRIDES: return DFTI_INPUT_STRIDES;
78-
case iparam::OUTPUT_STRIDES: return DFTI_OUTPUT_STRIDES;
7977
case iparam::FWD_DISTANCE: return DFTI_FWD_DISTANCE;
8078
case iparam::BWD_DISTANCE: return DFTI_BWD_DISTANCE;
8179
case iparam::WORKSPACE: return DFTI_WORKSPACE;

src/dft/backends/mklgpu/commit.cpp

Lines changed: 30 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp"
3434

3535
#include "dft/backends/mklgpu/mklgpu_helpers.hpp"
36+
#include "../stride_helper.hpp"
3637

3738
// MKLGPU header
3839
#include "oneapi/mkl/dfti.hpp"
@@ -90,14 +91,18 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
9091
config_values.workspace_placement ==
9192
oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);
9293

94+
auto stride_choice = dft::detail::get_stride_api(config_values);
95+
throw_on_invalid_stride_api("MKLGPU commit", stride_choice);
9396
// A separate descriptor for each direction may not be required.
94-
bool one_descriptor = config_values.input_strides == config_values.output_strides;
97+
bool one_descriptor = (stride_choice == dft::detail::stride_api::FB_STRIDES) ||
98+
(config_values.input_strides == config_values.output_strides);
9599
bool forward_good = true;
96-
// Make sure that second is always pointing to something new if this is a recommit.
100+
// Make sure that second is always pointing to something new if this is a recommit.
97101
handle.second = handle.first;
98102

99-
// Generate forward DFT descriptor.
100-
set_value(*handle.first, config_values, true);
103+
// Generate forward DFT descriptor. If using FWD/BWD_STRIDES API, only
104+
// one descriptor is needed.
105+
set_value(*handle.first, config_values, true, stride_choice);
101106
try {
102107
handle.first->commit(this->get_queue());
103108
}
@@ -114,7 +119,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
114119
// Generate backward DFT descriptor only if required.
115120
if (!one_descriptor) {
116121
handle.second = std::make_shared<mklgpu_descriptor_t>(config_values.dimensions);
117-
set_value(*handle.second, config_values, false);
122+
set_value(*handle.second, config_values, false, stride_choice);
118123
try {
119124
handle.second->commit(this->get_queue());
120125
}
@@ -160,7 +165,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
160165
handle_t handle;
161166

162167
void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config,
163-
bool assume_fwd_dft) {
168+
bool assume_fwd_dft, dft::detail::stride_api stride_choice) {
164169
using onemkl_param = dft::detail::config_param;
165170
using backend_param = dft::config_param;
166171

@@ -181,17 +186,27 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
181186
desc.set_value(backend_param::PLACEMENT,
182187
to_mklgpu<onemkl_param::PLACEMENT>(config.placement));
183188

184-
if (config.input_strides[0] != 0 || config.output_strides[0] != 0) {
185-
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
186-
"MKLGPU does not support nonzero offsets.");
187-
}
188-
if (assume_fwd_dft) {
189-
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
190-
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
189+
if (stride_choice == dft::detail::stride_api::FB_STRIDES) {
190+
if (config.fwd_strides[0] != 0 || config.fwd_strides[0] != 0) {
191+
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
192+
"MKLGPU does not support nonzero offsets.");
193+
}
194+
desc.set_value(backend_param::FWD_STRIDES, config.fwd_strides.data());
195+
desc.set_value(backend_param::BWD_STRIDES, config.bwd_strides.data());
191196
}
192197
else {
193-
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
194-
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
198+
if (config.input_strides[0] != 0 || config.output_strides[0] != 0) {
199+
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
200+
"MKLGPU does not support nonzero offsets.");
201+
}
202+
if (assume_fwd_dft) {
203+
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
204+
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
205+
}
206+
else {
207+
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
208+
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
209+
}
195210
}
196211
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
197212
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);

src/dft/backends/mklgpu/mklgpu_helpers.hpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,6 @@ inline constexpr dft::config_param to_mklgpu(dft::detail::config_param param) {
6666
case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE;
6767
case iparam::REAL_STORAGE: return oparam::REAL_STORAGE;
6868
case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE;
69-
case iparam::INPUT_STRIDES: return oparam::INPUT_STRIDES;
70-
case iparam::OUTPUT_STRIDES: return oparam::OUTPUT_STRIDES;
7169
case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE;
7270
case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE;
7371
case iparam::WORKSPACE: return oparam::WORKSPACE;

src/dft/backends/portfft/commit.cpp

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@
3535
#include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp"
3636
#include "oneapi/mkl/dft/types.hpp"
3737

38+
#include "../stride_helper.hpp"
39+
3840
#include "portfft_helper.hpp"
3941

4042
// alias to avoid ambiguity
@@ -87,6 +89,10 @@ class portfft_commit final : public dft::detail::commit_impl<prec, dom> {
8789
"portFFT does not supported transposed output");
8890
}
8991

92+
auto stride_api_choice = dft::detail::get_stride_api(config_values);
93+
dft::detail::throw_on_invalid_stride_api("portFFT commit", stride_api_choice);
94+
dft::detail::stride_vectors<std::int64_t> stride_vecs(config_values, stride_api_choice);
95+
9096
// forward descriptor
9197
pfft::descriptor<scalar_type, domain> fwd_desc(
9298
{ config_values.dimensions.cbegin(), config_values.dimensions.cend() });
@@ -100,12 +106,11 @@ class portfft_commit final : public dft::detail::commit_impl<prec, dom> {
100106
fwd_desc.placement = config_values.placement == config_value::INPLACE
101107
? pfft::placement::IN_PLACE
102108
: pfft::placement::OUT_OF_PLACE;
103-
fwd_desc.forward_offset = static_cast<std::size_t>(config_values.input_strides[0]);
104-
fwd_desc.backward_offset = static_cast<std::size_t>(config_values.output_strides[0]);
105-
fwd_desc.forward_strides = { config_values.input_strides.cbegin() + 1,
106-
config_values.input_strides.cend() };
107-
fwd_desc.backward_strides = { config_values.output_strides.cbegin() + 1,
108-
config_values.output_strides.cend() };
109+
fwd_desc.forward_offset = static_cast<std::size_t>(stride_vecs.offset_fwd_in);
110+
fwd_desc.backward_offset = static_cast<std::size_t>(stride_vecs.offset_fwd_out);
111+
fwd_desc.forward_strides = { stride_vecs.fwd_in.cbegin() + 1, stride_vecs.fwd_in.cend() };
112+
fwd_desc.backward_strides = { stride_vecs.fwd_out.cbegin() + 1,
113+
stride_vecs.fwd_out.cend() };
109114
fwd_desc.forward_distance = static_cast<std::size_t>(config_values.fwd_dist);
110115
fwd_desc.backward_distance = static_cast<std::size_t>(config_values.bwd_dist);
111116

@@ -122,12 +127,11 @@ class portfft_commit final : public dft::detail::commit_impl<prec, dom> {
122127
bwd_desc.placement = config_values.placement == config_value::INPLACE
123128
? pfft::placement::IN_PLACE
124129
: pfft::placement::OUT_OF_PLACE;
125-
bwd_desc.forward_offset = static_cast<std::size_t>(config_values.output_strides[0]);
126-
bwd_desc.backward_offset = static_cast<std::size_t>(config_values.input_strides[0]);
127-
bwd_desc.forward_strides = { config_values.output_strides.cbegin() + 1,
128-
config_values.output_strides.cend() };
129-
bwd_desc.backward_strides = { config_values.input_strides.cbegin() + 1,
130-
config_values.input_strides.cend() };
130+
bwd_desc.forward_offset = static_cast<std::size_t>(stride_vecs.offset_bwd_out);
131+
bwd_desc.backward_offset = static_cast<std::size_t>(stride_vecs.offset_bwd_in);
132+
bwd_desc.forward_strides = { stride_vecs.bwd_out.cbegin() + 1, stride_vecs.bwd_out.cend() };
133+
bwd_desc.backward_strides = { stride_vecs.bwd_in.cbegin() + 1,
134+
stride_vecs.bwd_in.cend() };
131135
bwd_desc.forward_distance = static_cast<std::size_t>(config_values.fwd_dist);
132136
bwd_desc.backward_distance = static_cast<std::size_t>(config_values.bwd_dist);
133137

0 commit comments

Comments
 (0)