Skip to content

Commit c248bc2

Browse files
hjabirdnormallytangent
authored andcommitted
[DFT][MKLGPU] Use FWD/BWD_STRIDES (#514)
* INPUT/OUTPUT_STRIDES are deprecated in the oneMKL spec and Intel(R) oneMLK 2024.1 (the latest release) * Using them causes a warning message to be printed * This PR uses the new API: FWD/BWD_STRIDES
1 parent b08b41f commit c248bc2

File tree

4 files changed

+81
-17
lines changed

4 files changed

+81
-17
lines changed

src/dft/backends/cufft/commit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
394394

395395
std::int64_t get_plan_workspace_size_bytes(cufftHandle handle) {
396396
std::size_t size = 0;
397-
cufftGetSize(*plans[0], &size);
397+
cufftGetSize(handle, &size);
398398
std::int64_t padded_size = static_cast<int64_t>(size);
399399
return padded_size;
400400
}

src/dft/backends/mklgpu/backward.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,15 @@ namespace detail {
4141
template <dft::detail::precision prec, dft::detail::domain dom, typename... ArgTs>
4242
inline auto compute_backward(dft::detail::descriptor<prec, dom> &desc, ArgTs &&... args) {
4343
using mklgpu_desc_t = dft::descriptor<to_mklgpu(prec), to_mklgpu(dom)>;
44+
using desc_shptr_t = std::shared_ptr<mklgpu_desc_t>;
45+
using handle_t = std::pair<desc_shptr_t, desc_shptr_t>;
4446
auto commit_handle = dft::detail::get_commit(desc);
4547
if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklgpu) {
4648
throw mkl::invalid_argument("DFT", "compute_backward",
4749
"DFT descriptor has not been commited for MKLGPU");
4850
}
49-
auto mklgpu_desc = reinterpret_cast<mklgpu_desc_t *>(commit_handle->get_handle());
51+
auto handle = reinterpret_cast<handle_t *>(commit_handle->get_handle());
52+
auto mklgpu_desc = handle->second; // Second because backward DFT.
5053
int commit_status{ DFTI_UNCOMMITTED };
5154
mklgpu_desc->get_value(dft::config_param::COMMIT_STATUS, &commit_status);
5255
if (commit_status != DFTI_COMMITTED) {

src/dft/backends/mklgpu/commit.cpp

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
// MKLGPU header
3838
#include "oneapi/mkl/dfti.hpp"
3939

40+
// MKL 2024.1 deprecates input/output strides.
41+
#include "mkl_version.h"
42+
#if INTEL_MKL_VERSION < 20240001
43+
#error MKLGPU requires oneMKL 2024.1 or later
44+
#endif
45+
4046
/**
4147
Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface
4248
of this OneMKL open-source library. Consequently, the types under dft::TYPE are closed-source oneMKL types,
@@ -53,14 +59,22 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
5359
// Equivalent MKLGPU precision and domain from OneMKL's precision / domain.
5460
static constexpr dft::precision mklgpu_prec = to_mklgpu(prec);
5561
static constexpr dft::domain mklgpu_dom = to_mklgpu(dom);
62+
63+
// A pair of descriptors are needed because of the [[deprecated]]IN/OUTPUT_STRIDES vs F/BWD_STRIDES API.
64+
// Of the pair [0] is fwd DFT, [1] is backward DFT. If possible, the pointers refer to the same desciptor.
65+
// Both pointers must be valid.
5666
using mklgpu_descriptor_t = dft::descriptor<mklgpu_prec, mklgpu_dom>;
67+
using descriptor_shptr_t = std::shared_ptr<mklgpu_descriptor_t>;
68+
using handle_t = std::pair<descriptor_shptr_t, descriptor_shptr_t>;
69+
5770
using scalar_type = typename dft::detail::commit_impl<prec, dom>::scalar_type;
5871

5972
public:
6073
mklgpu_commit(sycl::queue queue, const dft::detail::dft_values<prec, dom>& config_values)
6174
: oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::mklgpu,
6275
config_values),
63-
handle(config_values.dimensions) {
76+
handle(std::make_shared<mklgpu_descriptor_t>(config_values.dimensions), nullptr) {
77+
handle.second = handle.first; // Make sure the bwd pointer is valid.
6478
// MKLGPU does not throw an informative exception for the following:
6579
if constexpr (prec == dft::detail::precision::DOUBLE) {
6680
if (!queue.get_device().has(sycl::aspect::fp64)) {
@@ -75,13 +89,43 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
7589
oneapi::mkl::dft::detail::external_workspace_helper<prec, dom>(
7690
config_values.workspace_placement ==
7791
oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);
78-
set_value(handle, config_values);
92+
93+
// A separate descriptor for each direction may not be required.
94+
bool one_descriptor = config_values.input_strides == config_values.output_strides;
95+
bool forward_good = true;
96+
// Make sure that second is always pointing to something new if this is a recommit.
97+
handle.second = handle.first;
98+
99+
// Generate forward DFT descriptor.
100+
set_value(*handle.first, config_values, true);
79101
try {
80-
handle.commit(this->get_queue());
102+
handle.first->commit(this->get_queue());
81103
}
82104
catch (const std::exception& mkl_exception) {
83-
// Catching the real Intel oneMKL exception causes headaches with naming.
84-
throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what());
105+
// Catching the real Intel oneMKL exception causes headaches with naming
106+
forward_good = false;
107+
if (one_descriptor) {
108+
throw mkl::exception("dft/backends/mklgpu"
109+
"commit",
110+
mkl_exception.what());
111+
}
112+
}
113+
114+
// Generate backward DFT descriptor only if required.
115+
if (!one_descriptor) {
116+
handle.second = std::make_shared<mklgpu_descriptor_t>(config_values.dimensions);
117+
set_value(*handle.second, config_values, false);
118+
try {
119+
handle.second->commit(this->get_queue());
120+
}
121+
catch (const std::exception& mkl_exception) {
122+
// Catching the real Intel oneMKL exception causes headaches with naming.
123+
if (!forward_good) {
124+
throw mkl::exception("dft/backends/mklgpu"
125+
"commit",
126+
mkl_exception.what());
127+
}
128+
}
85129
}
86130
}
87131

@@ -93,12 +137,18 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
93137

94138
virtual void set_workspace(scalar_type* usm_workspace) override {
95139
this->external_workspace_helper_.set_workspace_throw(*this, usm_workspace);
96-
handle.set_workspace(usm_workspace);
140+
handle.first->set_workspace(usm_workspace);
141+
if (handle.first != handle.second) {
142+
handle.second->set_workspace(usm_workspace);
143+
}
97144
}
98145

99146
virtual void set_workspace(sycl::buffer<scalar_type>& buffer_workspace) override {
100147
this->external_workspace_helper_.set_workspace_throw(*this, buffer_workspace);
101-
handle.set_workspace(buffer_workspace);
148+
handle.first->set_workspace(buffer_workspace);
149+
if (handle.first != handle.second) {
150+
handle.second->set_workspace(buffer_workspace);
151+
}
102152
}
103153

104154
#define BACKEND mklgpu
@@ -107,9 +157,10 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
107157

108158
private:
109159
// The native MKLGPU class.
110-
mklgpu_descriptor_t handle;
160+
handle_t handle;
111161

112-
void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config) {
162+
void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config,
163+
bool assume_fwd_dft) {
113164
using onemkl_param = dft::detail::config_param;
114165
using backend_param = dft::config_param;
115166

@@ -134,8 +185,14 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
134185
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
135186
"MKLGPU does not support nonzero offsets.");
136187
}
137-
desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data());
138-
desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data());
188+
if (assume_fwd_dft) {
189+
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
190+
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
191+
}
192+
else {
193+
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
194+
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
195+
}
139196
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
140197
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
141198
if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) {
@@ -158,9 +215,10 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
158215

159216
// This is called by the workspace_helper, and is not part of the user API.
160217
virtual std::int64_t get_workspace_external_bytes_impl() override {
161-
std::size_t workspaceSize = 0;
162-
handle.get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSize);
163-
return static_cast<std::int64_t>(workspaceSize);
218+
std::size_t workspaceSizeFwd = 0, workspaceSizeBwd = 0;
219+
handle.first->get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSizeFwd);
220+
handle.second->get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSizeBwd);
221+
return static_cast<std::int64_t>(std::max(workspaceSizeFwd, workspaceSizeFwd));
164222
}
165223
};
166224
} // namespace detail

src/dft/backends/mklgpu/forward.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,15 @@ namespace detail {
4848
template <dft::detail::precision prec, dft::detail::domain dom, typename... ArgTs>
4949
inline auto compute_forward(dft::detail::descriptor<prec, dom> &desc, ArgTs &&... args) {
5050
using mklgpu_desc_t = dft::descriptor<to_mklgpu(prec), to_mklgpu(dom)>;
51+
using desc_shptr_t = std::shared_ptr<mklgpu_desc_t>;
52+
using handle_t = std::pair<desc_shptr_t, desc_shptr_t>;
5153
auto commit_handle = dft::detail::get_commit(desc);
5254
if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklgpu) {
5355
throw mkl::invalid_argument("DFT", "compute_forward",
5456
"DFT descriptor has not been commited for MKLGPU");
5557
}
56-
auto mklgpu_desc = reinterpret_cast<mklgpu_desc_t *>(commit_handle->get_handle());
58+
auto handle = reinterpret_cast<handle_t *>(commit_handle->get_handle());
59+
auto mklgpu_desc = handle->first; // First because forward DFT.
5760
int commit_status{ DFTI_UNCOMMITTED };
5861
mklgpu_desc->get_value(dft::config_param::COMMIT_STATUS, &commit_status);
5962
if (commit_status != DFTI_COMMITTED) {

0 commit comments

Comments
 (0)