forked from uxlfoundation/oneMath
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmklgpu_helpers.hpp
More file actions
156 lines (144 loc) · 6.04 KB
/
mklgpu_helpers.hpp
File metadata and controls
156 lines (144 loc) · 6.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/*******************************************************************************
* Copyright Codeplay Software Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/
#ifndef _ONEMATH_DFT_SRC_MKLGPU_HELPERS_HPP_
#define _ONEMATH_DFT_SRC_MKLGPU_HELPERS_HPP_
#include "oneapi/math/detail/exceptions.hpp"
#include "oneapi/math/dft/detail/types_impl.hpp"
// Intel(R) oneMKL headers
#include <mkl_version.h>
#if INTEL_MKL_VERSION < 20250000
#include <mkl/dfti.hpp>
namespace oneapi {
namespace math {
namespace dft {
namespace mklgpu {
namespace detail {

/// Backend commit-status constants for the DFTI-based oneMKL header
/// (this branch is selected when INTEL_MKL_VERSION < 20250000).
/// Alias of the backend's DFTI_COMMITTED value.
constexpr int committed = DFTI_COMMITTED;
/// Alias of the backend's DFTI_UNCOMMITTED value.
constexpr int uncommitted = DFTI_UNCOMMITTED;

} // namespace detail
} // namespace mklgpu
} // namespace dft
} // namespace math
} // namespace oneapi
#else
#include <mkl/dft.hpp>
namespace oneapi {
namespace math {
namespace dft {
namespace mklgpu {
namespace detail {

/// Backend commit-status constants for the config_value-based oneMKL header
/// (this branch is selected when INTEL_MKL_VERSION >= 20250000).
/// Alias of the backend's config_value::COMMITTED enumerator.
constexpr auto committed = oneapi::mkl::dft::config_value::COMMITTED;
/// Alias of the backend's config_value::UNCOMMITTED enumerator.
constexpr auto uncommitted = oneapi::mkl::dft::config_value::UNCOMMITTED;

} // namespace detail
} // namespace mklgpu
} // namespace dft
} // namespace math
} // namespace oneapi
#endif
namespace oneapi {
namespace math {
namespace dft {
namespace mklgpu {
namespace detail {
/// Map a oneMath DFT domain onto the backend's native domain enumerator.
/// REAL maps to REAL; every other input maps to COMPLEX.
inline constexpr oneapi::mkl::dft::domain to_mklgpu(dft::detail::domain dom) {
    using backend_domain = oneapi::mkl::dft::domain;
    return (dom == dft::detail::domain::REAL) ? backend_domain::REAL : backend_domain::COMPLEX;
}
/// Map a oneMath DFT precision onto the backend's native precision enumerator.
/// SINGLE maps to SINGLE; every other input maps to DOUBLE.
/// Note: the parameter carries a precision even though it is named `dom`.
inline constexpr oneapi::mkl::dft::precision to_mklgpu(dft::detail::precision dom) {
    using backend_precision = oneapi::mkl::dft::precision;
    return (dom == dft::detail::precision::SINGLE) ? backend_precision::SINGLE
                                                   : backend_precision::DOUBLE;
}
// Disabled (commented-out) helper: convert a config_param to equivalent backend native value.
/*inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_param param) {
using iparam = dft::detail::config_param;
using oparam = oneapi::mkl::dft::config_param;
switch (param) {
case iparam::FORWARD_DOMAIN: return oparam::FORWARD_DOMAIN;
case iparam::DIMENSION: return oparam::DIMENSION;
case iparam::LENGTHS: return oparam::LENGTHS;
case iparam::PRECISION: return oparam::PRECISION;
case iparam::FORWARD_SCALE: return oparam::FORWARD_SCALE;
case iparam::NUMBER_OF_TRANSFORMS: return oparam::NUMBER_OF_TRANSFORMS;
case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE;
case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE;
case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE;
case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE;
case iparam::WORKSPACE: return oparam::WORKSPACE;
case iparam::PACKED_FORMAT: return oparam::PACKED_FORMAT;
case iparam::WORKSPACE_PLACEMENT: return oparam::WORKSPACE; // Same as WORKSPACE
case iparam::WORKSPACE_EXTERNAL_BYTES: return oparam::WORKSPACE_BYTES;
case iparam::COMMIT_STATUS: return oparam::COMMIT_STATUS;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config param.");
return static_cast<oparam>(0);
}
}*/
/// Functor template that converts a config_value belonging to the config param `Param`
/// into the backend's native value. Only declared here: no primary definition is visible
/// in this header, so the behaviour comes from the explicit specializations that follow
/// (PLACEMENT and WORKSPACE_PLACEMENT).
template <dft::detail::config_param Param>
struct to_mklgpu_impl;
/** Translate a config_value into the backend's native value by delegating to the
 * to_mklgpu_impl specialization selected by Param. Throws on invalid input.
 * @tparam Param The config param that the value belongs to.
 * @param value The config value to translate.
 * @return Whatever the selected to_mklgpu_impl specialization returns.
 **/
template <dft::detail::config_param Param>
inline constexpr auto to_mklgpu(dft::detail::config_value value) {
    using converter = to_mklgpu_impl<Param>;
    return converter{}(value);
}
#if INTEL_MKL_VERSION < 20250000
template <>
struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) -> int {
switch (value) {
case dft::detail::config_value::INPLACE: return DFTI_INPLACE;
case dft::detail::config_value::NOT_INPLACE: return DFTI_NOT_INPLACE;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
#else
template <>
struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> {
inline constexpr auto operator()(dft::detail::config_value value) {
switch (value) {
case dft::detail::config_value::INPLACE: return oneapi::mkl::dft::config_value::INPLACE;
case dft::detail::config_value::NOT_INPLACE:
return oneapi::mkl::dft::config_value::NOT_INPLACE;
default:
throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
"Invalid config value for inplace.");
}
}
};
#endif
/// WORKSPACE_PLACEMENT conversion to the backend's config_value enumerators.
/// WORKSPACE_AUTOMATIC maps to the backend's WORKSPACE_INTERNAL and WORKSPACE_EXTERNAL maps
/// to WORKSPACE_EXTERNAL.
/// @param value The workspace-placement config value to convert.
/// @return The matching oneapi::mkl::dft::config_value enumerator.
/// @throws math::invalid_argument if value is neither WORKSPACE_AUTOMATIC nor
///         WORKSPACE_EXTERNAL.
template <>
struct to_mklgpu_impl<dft::detail::config_param::WORKSPACE_PLACEMENT> {
    inline constexpr auto operator()(dft::detail::config_value value) {
        switch (value) {
            case dft::detail::config_value::WORKSPACE_AUTOMATIC:
                return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL;
            case dft::detail::config_value::WORKSPACE_EXTERNAL:
                return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL;
            default:
                // The message names the parameter actually being converted; it previously
                // said "inplace", copied from the PLACEMENT specialization.
                throw math::invalid_argument("dft", "MKLGPU descriptor set_value()",
                                             "Invalid config value for workspace placement.");
        }
    }
};
} // namespace detail
} // namespace mklgpu
} // namespace dft
} // namespace math
} // namespace oneapi
#endif // _ONEMATH_DFT_SRC_MKLGPU_HELPERS_HPP_