Skip to content

Commit f58963c

Browse files
Merge branch 'develop' of github.com:normallytangent/oneMKL into develop
2 parents e420e88 + ccd3822 commit f58963c

24 files changed

+818
-515
lines changed

include/oneapi/mkl/dft/detail/commit_impl.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ enum class backend;
3232

3333
namespace oneapi::mkl::dft::detail {
3434

35+
enum class precision;
36+
enum class domain;
37+
template <precision prec, domain dom>
38+
class dft_values;
39+
40+
template <precision prec, domain dom>
3541
class commit_impl {
3642
public:
3743
// Stores the queue and backend this commit implementation is bound to.
// Initializers are listed in member declaration order (backend_, then queue_),
// which is the order in which they are actually initialized.
commit_impl(sycl::queue queue, mkl::backend backend) : backend_(backend), queue_(queue) {}
@@ -51,6 +57,8 @@ class commit_impl {
5157

5258
virtual void* get_handle() noexcept = 0;
5359

60+
virtual void commit(const dft_values<prec, dom>&) = 0;
61+
5462
private:
5563
mkl::backend backend_;
5664
sycl::queue queue_;

include/oneapi/mkl/dft/detail/descriptor_impl.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ template <precision prec, domain dom>
4141
class descriptor;
4242

4343
template <precision prec, domain dom>
44-
inline commit_impl* get_commit(descriptor<prec, dom>& desc);
44+
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc);
4545

4646
template <precision prec, domain dom>
4747
class descriptor {
@@ -74,16 +74,16 @@ class descriptor {
7474

7575
private:
7676
// Has a value when the descriptor is committed.
77-
std::unique_ptr<commit_impl> pimpl_;
77+
std::unique_ptr<commit_impl<prec, dom>> pimpl_;
7878

7979
// descriptor configuration values_ and structs
8080
dft_values<prec, dom> values_;
8181

82-
friend commit_impl* get_commit<prec, dom>(descriptor<prec, dom>&);
82+
friend commit_impl<prec, dom>* get_commit<prec, dom>(descriptor<prec, dom>&);
8383
};
8484

8585
template <precision prec, domain dom>
86-
inline commit_impl* get_commit(descriptor<prec, dom>& desc) {
86+
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc) {
8787
return desc.pimpl_.get();
8888
}
8989

include/oneapi/mkl/dft/detail/dft_ct.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
// Commit
2121

2222
template <dft::detail::precision prec, dft::detail::domain dom>
23-
ONEMKL_EXPORT dft::detail::commit_impl *create_commit(
23+
ONEMKL_EXPORT dft::detail::commit_impl<prec, dom> *create_commit(
2424
const dft::detail::descriptor<prec, dom> &desc, sycl::queue &sycl_queue);
2525

2626
// BUFFER version

include/oneapi/mkl/dft/detail/dft_loader.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ namespace mkl {
3434
namespace dft {
3535
namespace detail {
3636

37+
template <precision prec, domain dom>
3738
class commit_impl;
3839

3940
template <precision prec, domain dom>
4041
class descriptor;
4142

4243
template <precision prec, domain dom>
43-
ONEMKL_EXPORT commit_impl* create_commit(const descriptor<prec, dom>& desc, sycl::queue& queue);
44+
ONEMKL_EXPORT commit_impl<prec, dom>* create_commit(const descriptor<prec, dom>& desc,
45+
sycl::queue& queue);
4446

4547
} // namespace detail
4648
} // namespace dft

include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace dft {
3535

3636
namespace detail {
3737
// Forward declarations
38+
template <precision prec, domain dom>
3839
class commit_impl;
3940

4041
template <precision prec, domain dom>

include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace dft {
3535

3636
namespace detail {
3737
// Forward declarations
38+
template <precision prec, domain dom>
3839
class commit_impl;
3940

4041
template <precision prec, domain dom>

src/dft/backends/descriptor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ namespace dft {
2828

2929
template <precision prec, domain dom>
3030
void descriptor<prec, dom>::commit(sycl::queue &queue) {
31-
pimpl_.reset(detail::create_commit(*this, queue));
31+
if (!pimpl_ || pimpl_->get_queue() != queue) {
32+
if (pimpl_) {
33+
pimpl_->get_queue().wait();
34+
}
35+
pimpl_.reset(detail::create_commit(*this, queue));
36+
}
37+
pimpl_->commit(values_);
3238
}
3339
template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(sycl::queue &);
3440
template void descriptor<precision::SINGLE, domain::REAL>::commit(sycl::queue &);

src/dft/backends/mklcpu/commit.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ namespace dft {
4040
namespace mklcpu {
4141

4242
template <precision prec, domain dom>
43-
class commit_derived_impl final : public detail::commit_impl {
43+
class commit_derived_impl final : public detail::commit_impl<prec, dom> {
4444
public:
4545
commit_derived_impl(sycl::queue queue, const detail::dft_values<prec, dom>& config_values)
46-
: detail::commit_impl(queue, backend::mklcpu) {
46+
: detail::commit_impl<prec, dom>(queue, backend::mklcpu) {
4747
DFT_ERROR status = DFT_NOTSET;
4848
if (config_values.dimensions.size() == 1) {
4949
status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), 1,
@@ -55,16 +55,19 @@ class commit_derived_impl final : public detail::commit_impl {
5555
config_values.dimensions.data());
5656
}
5757
if (status != DFTI_NO_ERROR) {
58-
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
59-
"DftiCreateDescriptor failed");
58+
throw oneapi::mkl::exception(
59+
"dft/backends/mklcpu", "commit",
60+
"DftiCreateDescriptor failed with status: " + std::to_string(status));
6061
}
62+
}
6163

64+
void commit(const detail::dft_values<prec, dom>& config_values) override {
6265
set_value(handle, config_values);
63-
64-
status = DftiCommitDescriptor(handle);
66+
auto status = DftiCommitDescriptor(handle);
6567
if (status != DFTI_NO_ERROR) {
66-
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
67-
"DftiCommitDescriptor failed");
68+
throw oneapi::mkl::exception(
69+
"dft/backends/mklcpu", "commit",
70+
"DftiCommitDescriptor failed with status: " + std::to_string(status));
6871
}
6972
}
7073

@@ -122,18 +125,19 @@ class commit_derived_impl final : public detail::commit_impl {
122125
};
123126

124127
template <precision prec, domain dom>
125-
detail::commit_impl* create_commit(const descriptor<prec, dom>& desc, sycl::queue& sycl_queue) {
128+
detail::commit_impl<prec, dom>* create_commit(const descriptor<prec, dom>& desc,
129+
sycl::queue& sycl_queue) {
126130
return new commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
127131
}
128132

129-
template detail::commit_impl* create_commit(const descriptor<precision::SINGLE, domain::REAL>&,
130-
sycl::queue&);
131-
template detail::commit_impl* create_commit(const descriptor<precision::SINGLE, domain::COMPLEX>&,
132-
sycl::queue&);
133-
template detail::commit_impl* create_commit(const descriptor<precision::DOUBLE, domain::REAL>&,
134-
sycl::queue&);
135-
template detail::commit_impl* create_commit(const descriptor<precision::DOUBLE, domain::COMPLEX>&,
136-
sycl::queue&);
133+
template detail::commit_impl<precision::SINGLE, domain::REAL>* create_commit(
134+
const descriptor<precision::SINGLE, domain::REAL>&, sycl::queue&);
135+
template detail::commit_impl<precision::SINGLE, domain::COMPLEX>* create_commit(
136+
const descriptor<precision::SINGLE, domain::COMPLEX>&, sycl::queue&);
137+
template detail::commit_impl<precision::DOUBLE, domain::REAL>* create_commit(
138+
const descriptor<precision::DOUBLE, domain::REAL>&, sycl::queue&);
139+
template detail::commit_impl<precision::DOUBLE, domain::COMPLEX>* create_commit(
140+
const descriptor<precision::DOUBLE, domain::COMPLEX>&, sycl::queue&);
137141

138142
} // namespace mklcpu
139143
} // namespace dft

src/dft/backends/mklcpu/descriptor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ namespace dft {
2828

2929
template <precision prec, domain dom>
3030
void descriptor<prec, dom>::commit(backend_selector<backend::mklcpu> selector) {
31-
pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue()));
31+
if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
32+
if (pimpl_) {
33+
pimpl_->get_queue().wait();
34+
}
35+
// This overload takes backend_selector<backend::mklcpu>, so the implementation
// must be built by the MKLCPU factory (as the replaced line did), not the MKLGPU one.
pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue()));
36+
}
37+
pimpl_->commit(values_);
3238
}
3339

3440
template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(

src/dft/backends/mklgpu/commit.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace detail {
5151

5252
/// Commit impl class specialization for MKLGPU.
5353
template <dft::detail::precision prec, dft::detail::domain dom>
54-
class commit_derived_impl final : public dft::detail::commit_impl {
54+
class commit_derived_impl final : public dft::detail::commit_impl<prec, dom> {
5555
private:
5656
// Equivalent MKLGPU precision and domain from OneMKL's precision / domain.
5757
static constexpr dft::precision mklgpu_prec = to_mklgpu(prec);
@@ -60,19 +60,21 @@ class commit_derived_impl final : public dft::detail::commit_impl {
6060

6161
public:
6262
commit_derived_impl(sycl::queue queue, const dft::detail::dft_values<prec, dom>& config_values)
63-
: oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu),
63+
: oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::mklgpu),
6464
handle(config_values.dimensions) {
65-
set_value(handle, config_values);
6665
// MKLGPU does not throw an informative exception for the following:
6766
if constexpr (prec == dft::detail::precision::DOUBLE) {
6867
if (!queue.get_device().has(sycl::aspect::fp64)) {
6968
throw mkl::exception("dft/backends/mklgpu", "commit",
7069
"Device does not support double precision.");
7170
}
7271
}
72+
}
7373

74+
// Applies config_values to the backend handle and commits it on this object's queue.
// `override` already implies virtual; the redundant keyword is dropped to match the
// MKLCPU backend's declaration of the same override.
void commit(const dft::detail::dft_values<prec, dom>& config_values) override {
75+
set_value(handle, config_values);
7476
try {
75-
handle.commit(queue);
77+
handle.commit(this->get_queue());
7678
}
7779
catch (const std::exception& mkl_exception) {
7880
// Catching the real MKL exception causes headaches with naming.
@@ -125,28 +127,30 @@ class commit_derived_impl final : public dft::detail::commit_impl {
125127
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
126128
"MKLGPU only supports non-transposed.");
127129
}
128-
desc.set_value(backend_param::PACKED_FORMAT,
129-
to_mklgpu<onemkl_param::PACKED_FORMAT>(config.packed_format));
130130
}
131131
};
132132
} // namespace detail
133133

134134
template <dft::detail::precision prec, dft::detail::domain dom>
135-
dft::detail::commit_impl* create_commit(const dft::detail::descriptor<prec, dom>& desc,
136-
sycl::queue& sycl_queue) {
135+
dft::detail::commit_impl<prec, dom>* create_commit(const dft::detail::descriptor<prec, dom>& desc,
136+
sycl::queue& sycl_queue) {
137137
return new detail::commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
138138
}
139139

140-
template dft::detail::commit_impl* create_commit(
140+
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::REAL>*
141+
create_commit(
141142
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::REAL>&,
142143
sycl::queue&);
143-
template dft::detail::commit_impl* create_commit(
144+
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>*
145+
create_commit(
144146
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>&,
145147
sycl::queue&);
146-
template dft::detail::commit_impl* create_commit(
148+
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>*
149+
create_commit(
147150
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>&,
148151
sycl::queue&);
149-
template dft::detail::commit_impl* create_commit(
152+
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>*
153+
create_commit(
150154
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>&,
151155
sycl::queue&);
152156

0 commit comments

Comments
 (0)