Skip to content

Commit 0d301e6

Browse files
hjabirdnormallytangent
authored and committed
[DFT] Correct overload resolution for OOP COMPLEX vs IP REAL_REAL (#503)
* OOP COMPLEX and IP REAL_REAL overload resolution is problematic
* Inplace real-real overload would be selected when out-of-place complex-complex DFT was intended.
* With spec update, this PR uses SFINAE to give the expected behaviour for the user.
1 parent 08adee2 commit 0d301e6

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

include/oneapi/mkl/dft/backward.hpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,8 @@ void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout)
4444
}
4545

4646
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
47-
template <typename descriptor_type, typename data_type>
47+
template <typename descriptor_type, typename data_type,
48+
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
4849
void compute_backward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
4950
sycl::buffer<data_type, 1> &inout_im) {
5051
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
@@ -114,7 +115,8 @@ sycl::event compute_backward(descriptor_type &desc, data_type *inout,
114115
}
115116

116117
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
117-
template <typename descriptor_type, typename data_type>
118+
template <typename descriptor_type, typename data_type,
119+
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
118120
sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
119121
const std::vector<sycl::event> &dependencies = {}) {
120122
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,

include/oneapi/mkl/dft/detail/types_impl.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,34 @@ struct descriptor_info<descriptor<precision::DOUBLE, domain::COMPLEX>> {
8787
using backward_type = std::complex<double>;
8888
};
8989

90+
// Get the scalar type associated with a descriptor.
91+
template <class descriptor_t>
92+
using descriptor_scalar_t = typename descriptor_info<descriptor_t>::scalar_type;
93+
94+
template <typename T>
95+
constexpr bool is_complex_dft = false;
96+
template <precision Prec>
97+
constexpr bool is_complex_dft<descriptor<Prec, domain::COMPLEX>> = true;
98+
99+
template <typename T>
100+
constexpr bool is_complex = false;
101+
template <typename T>
102+
constexpr bool is_complex<std::complex<T>> = true;
103+
90104
template <typename T, typename... Ts>
91105
using is_one_of = typename std::bool_constant<(std::is_same_v<T, Ts> || ...)>;
92106

93107
template <typename descriptor_type, typename T>
94108
using valid_compute_arg = typename std::bool_constant<
95-
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, float> &&
109+
(std::is_same_v<descriptor_scalar_t<descriptor_type>, float> &&
96110
is_one_of<T, float, sycl::float2, sycl::float4, std::complex<float>>::value) ||
97-
(std::is_same_v<typename detail::descriptor_info<descriptor_type>::scalar_type, double> &&
111+
(std::is_same_v<descriptor_scalar_t<descriptor_type>, double> &&
98112
is_one_of<T, double, sycl::double2, sycl::double4, std::complex<double>>::value)>;
99113

114+
template <class descriptor_t, typename data_t>
115+
constexpr bool valid_ip_realreal_impl =
116+
is_complex_dft<descriptor_t>&& std::is_same_v<descriptor_scalar_t<descriptor_t>, data_t>;
117+
100118
// compute the range of a reinterpreted buffer
101119
template <typename In, typename Out>
102120
std::size_t reinterpret_range(std::size_t size) {

include/oneapi/mkl/dft/forward.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout) {
4545
}
4646

4747
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
48-
template <typename descriptor_type, typename data_type>
48+
template <typename descriptor_type, typename data_type,
49+
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
4950
void compute_forward(descriptor_type &desc, sycl::buffer<data_type, 1> &inout_re,
5051
sycl::buffer<data_type, 1> &inout_im) {
5152
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
@@ -114,12 +115,12 @@ sycl::event compute_forward(descriptor_type &desc, data_type *inout,
114115
}
115116

116117
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
117-
template <typename descriptor_type, typename data_type>
118+
template <typename descriptor_type, typename data_type,
119+
std::enable_if_t<detail::valid_ip_realreal_impl<descriptor_type, data_type>, bool> = true>
118120
sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im,
119121
const std::vector<sycl::event> &dependencies = {}) {
120122
static_assert(detail::valid_compute_arg<descriptor_type, data_type>::value,
121123
"unexpected type for data_type");
122-
123124
using scalar_type = typename detail::descriptor_info<descriptor_type>::scalar_type;
124125
return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast<scalar_type *>(inout_re),
125126
reinterpret_cast<scalar_type *>(inout_im), dependencies);
@@ -133,7 +134,6 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *
133134
"unexpected type for input_type");
134135
static_assert(detail::valid_compute_arg<descriptor_type, output_type>::value,
135136
"unexpected type for output_type");
136-
137137
using fwd_type = typename detail::descriptor_info<descriptor_type>::forward_type;
138138
using bwd_type = typename detail::descriptor_info<descriptor_type>::backward_type;
139139
return get_commit(desc)->forward_op_cc(desc, reinterpret_cast<fwd_type *>(in),

0 commit comments

Comments
 (0)