Skip to content

Commit b86d5fc

Browse files
authored
[enhancement] Refactor onedal/datatypes in preparation for dlpack support (#2195)
* move to numpy and sycl_usm folders * fix pre-commit * fix pickling * forgotten numpy namespace is BS * missing another numpy * missing another numpy * add missing sycl_usm * remove table_metadata * remove unneeded includes and move dtype_dispatcher into a central location * remove header reference * missed save * Delete onedal/datatypes/dtype_dispatcher.hpp * Revert "Delete onedal/datatypes/dtype_dispatcher.hpp" This reverts commit bfc66b6. * helper -> utils * move macro to a central spot
1 parent 1b6d537 commit b86d5fc

File tree

17 files changed

+130
-130
lines changed

17 files changed

+130
-130
lines changed

onedal/basic_statistics/basic_statistics.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "onedal/version.hpp"
2121

2222
#define NO_IMPORT_ARRAY // import_array called in table.cpp
23-
#include "onedal/datatypes/data_conversion.hpp"
23+
#include "onedal/datatypes/numpy/data_conversion.hpp"
2424

2525
#include <string>
2626
#include <regex>
@@ -210,30 +210,30 @@ void init_partial_compute_result(py::module_& m) {
210210
.def(py::pickle(
211211
[](const result_t& res) {
212212
return py::make_tuple(
213-
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
214-
py::cast<py::object>(convert_to_pyobject(res.get_partial_min())),
215-
py::cast<py::object>(convert_to_pyobject(res.get_partial_max())),
216-
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())),
217-
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum_squares())),
213+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_n_rows())),
214+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_min())),
215+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_max())),
216+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum())),
217+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum_squares())),
218218
py::cast<py::object>(
219-
convert_to_pyobject(res.get_partial_sum_squares_centered())));
219+
numpy::convert_to_pyobject(res.get_partial_sum_squares_centered())));
220220
},
221221
[](py::tuple t) {
222222
if (t.size() != 6)
223223
throw std::runtime_error("Invalid state!");
224224
result_t res;
225225
if (py::cast<int>(t[0].attr("size")) != 0)
226-
res.set_partial_n_rows(convert_to_table(t[0]));
226+
res.set_partial_n_rows(numpy::convert_to_table(t[0]));
227227
if (py::cast<int>(t[1].attr("size")) != 0)
228-
res.set_partial_min(convert_to_table(t[1]));
228+
res.set_partial_min(numpy::convert_to_table(t[1]));
229229
if (py::cast<int>(t[2].attr("size")) != 0)
230-
res.set_partial_max(convert_to_table(t[2]));
230+
res.set_partial_max(numpy::convert_to_table(t[2]));
231231
if (py::cast<int>(t[3].attr("size")) != 0)
232-
res.set_partial_sum(convert_to_table(t[3]));
232+
res.set_partial_sum(numpy::convert_to_table(t[3]));
233233
if (py::cast<int>(t[4].attr("size")) != 0)
234-
res.set_partial_sum_squares(convert_to_table(t[4]));
234+
res.set_partial_sum_squares(numpy::convert_to_table(t[4]));
235235
if (py::cast<int>(t[5].attr("size")) != 0)
236-
res.set_partial_sum_squares_centered(convert_to_table(t[5]));
236+
res.set_partial_sum_squares_centered(numpy::convert_to_table(t[5]));
237237

238238
return res;
239239
}));

onedal/covariance/covariance.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "oneapi/dal/algo/covariance.hpp"
2020

2121
#define NO_IMPORT_ARRAY // import_array called in table.cpp
22-
#include "onedal/datatypes/data_conversion.hpp"
22+
#include "onedal/datatypes/numpy/data_conversion.hpp"
2323

2424
#include "onedal/common.hpp"
2525
#include "onedal/version.hpp"
@@ -141,20 +141,21 @@ inline void init_partial_compute_result(pybind11::module_& m) {
141141
.def(py::pickle(
142142
[](const result_t& res) {
143143
return py::make_tuple(
144-
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
145-
py::cast<py::object>(convert_to_pyobject(res.get_partial_crossproduct())),
146-
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())));
144+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_n_rows())),
145+
py::cast<py::object>(
146+
numpy::convert_to_pyobject(res.get_partial_crossproduct())),
147+
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum())));
147148
},
148149
[](py::tuple t) {
149150
if (t.size() != 3)
150151
throw std::runtime_error("Invalid state!");
151152
result_t res;
152153
if (py::cast<int>(t[0].attr("size")) != 0)
153-
res.set_partial_n_rows(convert_to_table(t[0]));
154+
res.set_partial_n_rows(numpy::convert_to_table(t[0]));
154155
if (py::cast<int>(t[1].attr("size")) != 0)
155-
res.set_partial_crossproduct(convert_to_table(t[1]));
156+
res.set_partial_crossproduct(numpy::convert_to_table(t[1]));
156157
if (py::cast<int>(t[2].attr("size")) != 0)
157-
res.set_partial_sum(convert_to_table(t[2]));
158+
res.set_partial_sum(numpy::convert_to_table(t[2]));
158159
return res;
159160
}));
160161
}

onedal/datatypes/utils/dtype_dispatcher.hpp renamed to onedal/datatypes/dtype_dispatcher.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,27 @@ constexpr inline void apply(Op&& op, Args&&... args) {
8686

8787
#endif // Version check
8888

89+
#define SET_CTYPE_FROM_DAL_TYPE(_T, _FUNCT, _EXCEPTION) \
90+
switch (_T) { \
91+
case dal::data_type::float32: { \
92+
_FUNCT(float); \
93+
break; \
94+
} \
95+
case dal::data_type::float64: { \
96+
_FUNCT(double); \
97+
break; \
98+
} \
99+
case dal::data_type::int32: { \
100+
_FUNCT(std::int32_t); \
101+
break; \
102+
} \
103+
case dal::data_type::int64: { \
104+
_FUNCT(std::int64_t); \
105+
break; \
106+
} \
107+
default: _EXCEPTION; \
108+
};
109+
89110
namespace oneapi::dal::python {
90111

91112
using supported_types_t = std::tuple<float,

onedal/datatypes/data_conversion.cpp renamed to onedal/datatypes/numpy/data_conversion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
#include "oneapi/dal/table/homogen.hpp"
2323
#include "oneapi/dal/table/detail/homogen_utils.hpp"
2424

25-
#include "onedal/datatypes/data_conversion.hpp"
26-
#include "onedal/datatypes/utils/numpy_helpers.hpp"
25+
#include "onedal/datatypes/numpy/data_conversion.hpp"
26+
#include "onedal/datatypes/numpy/numpy_utils.hpp"
2727
#include "onedal/version.hpp"
2828

2929
#if ONEDAL_VERSION <= 20230100
@@ -32,7 +32,7 @@
3232
#include "oneapi/dal/table/csr.hpp"
3333
#endif
3434

35-
namespace oneapi::dal::python {
35+
namespace oneapi::dal::python::numpy {
3636

3737
#if ONEDAL_VERSION <= 20230100
3838
typedef oneapi::dal::detail::csr_table csr_table_t;
@@ -432,4 +432,4 @@ PyObject *convert_to_pyobject(const dal::table &input) {
432432
return res;
433433
}
434434

435-
} // namespace oneapi::dal::python
435+
} // namespace oneapi::dal::python::numpy

onedal/datatypes/data_conversion.hpp renamed to onedal/datatypes/numpy/data_conversion.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525

2626
#include "oneapi/dal/table/common.hpp"
2727

28-
namespace oneapi::dal::python {
28+
namespace oneapi::dal::python::numpy {
2929

3030
namespace py = pybind11;
3131

3232
PyObject *convert_to_pyobject(const dal::table &input);
3333
dal::table convert_to_table(py::object inp_obj, py::object queue = py::none());
3434

35-
} // namespace oneapi::dal::python
35+
} // namespace oneapi::dal::python::numpy

onedal/datatypes/utils/numpy_helpers.cpp renamed to onedal/datatypes/numpy/numpy_utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17-
#include "onedal/datatypes/utils/numpy_helpers.hpp"
17+
#include "onedal/datatypes/numpy/numpy_utils.hpp"
1818

19-
namespace oneapi::dal::python {
19+
namespace oneapi::dal::python::numpy {
2020

2121
template <typename Key, typename Value>
2222
auto reverse_map(const std::map<Key, Value>& input) {
@@ -50,4 +50,4 @@ npy_dtype_t convert_dal_to_npy_type(dal::data_type type) {
5050
return get_dal_to_npy_map().at(type);
5151
}
5252

53-
} // namespace oneapi::dal::python
53+
} // namespace oneapi::dal::python::numpy

onedal/datatypes/utils/numpy_helpers.hpp renamed to onedal/datatypes/numpy/numpy_utils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@
140140
#define array_data(a) PyArray_DATA((PyArrayObject *)a)
141141
#define array_size(a, i) PyArray_DIM((PyArrayObject *)a, i)
142142

143-
namespace oneapi::dal::python {
143+
namespace oneapi::dal::python::numpy {
144144

145145
using npy_dtype_t = decltype(NPY_FLOAT);
146146
using npy_to_dal_t = std::map<npy_dtype_t, dal::data_type>;
@@ -152,4 +152,4 @@ const dal_to_npy_t &get_dal_to_npy_map();
152152
dal::data_type convert_npy_to_dal_type(npy_dtype_t);
153153
npy_dtype_t convert_dal_to_npy_type(dal::data_type);
154154

155-
} // namespace oneapi::dal::python
155+
} // namespace oneapi::dal::python::numpy

onedal/datatypes/data_conversion_sua_iface.cpp renamed to onedal/datatypes/sycl_usm/data_conversion.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@
2727
#include "oneapi/dal/table/detail/homogen_utils.hpp"
2828

2929
#include "onedal/common/sycl_interfaces.hpp"
30-
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
31-
#include "onedal/datatypes/utils/dtype_conversions.hpp"
32-
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
33-
#include "onedal/datatypes/utils/sua_iface_helpers.hpp"
30+
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
31+
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"
32+
#include "onedal/datatypes/sycl_usm/sycl_usm_utils.hpp"
3433

35-
namespace oneapi::dal::python {
34+
namespace oneapi::dal::python::sycl_usm {
3635

3736
using namespace pybind11::literals;
3837
// Please follow <https://intelpython.github.io/dpctl/latest/
@@ -128,7 +127,7 @@ dal::table convert_to_homogen_impl(py::object obj) {
128127
}
129128

130129
// Convert oneDAL table with zero-copy by use of `__sycl_usm_array_interface__` protocol.
131-
dal::table convert_from_sua_iface(py::object obj) {
130+
dal::table convert_to_table(py::object obj) {
132131
// Get `__sycl_usm_array_interface__` dictionary representing USM allocations.
133132
auto sua_iface_dict = get_sua_interface(obj);
134133

@@ -236,6 +235,6 @@ void define_sycl_usm_array_property(py::class_<dal::table>& table_obj) {
236235
table_obj.def_property_readonly("__sycl_usm_array_interface__", &construct_sua_iface);
237236
}
238237

239-
} // namespace oneapi::dal::python
238+
} // namespace oneapi::dal::python::sycl_usm
240239

241240
#endif // ONEDAL_DATA_PARALLEL

onedal/datatypes/data_conversion_sua_iface.hpp renamed to onedal/datatypes/sycl_usm/data_conversion.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323

2424
#include "oneapi/dal/table/common.hpp"
2525

26-
namespace oneapi::dal::python {
26+
namespace oneapi::dal::python::sycl_usm {
2727

2828
namespace py = pybind11;
2929

3030
// Convert oneDAL table with zero-copy by use of `__sycl_usm_array_interface__` protocol.
31-
dal::table convert_from_sua_iface(py::object obj);
31+
dal::table convert_to_table(py::object obj);
3232

3333
// Create a dictionary for `__sycl_usm_array_interface__` protocol from oneDAL table properties.
3434
py::dict construct_sua_iface(const dal::table& input);
@@ -37,4 +37,4 @@ py::dict construct_sua_iface(const dal::table& input);
3737
// USM allocations.
3838
void define_sycl_usm_array_property(py::class_<dal::table>& t);
3939

40-
} // namespace oneapi::dal::python
40+
} // namespace oneapi::dal::python::sycl_usm

onedal/datatypes/utils/dtype_conversions.cpp renamed to onedal/datatypes/sycl_usm/dtype_conversion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
#include "oneapi/dal/common.hpp"
2222
#include "oneapi/dal/detail/common.hpp"
2323

24-
#include "onedal/datatypes/utils/dtype_conversions.hpp"
25-
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
24+
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"
25+
#include "onedal/datatypes/dtype_dispatcher.hpp"
2626

27-
namespace oneapi::dal::python {
27+
namespace oneapi::dal::python::sycl_usm {
2828

2929
using fwd_map_t = std::unordered_map<std::string, dal::data_type>;
3030
using inv_map_t = std::unordered_map<dal::data_type, std::string>;
@@ -139,4 +139,4 @@ std::string convert_dal_to_sua_type(dal::data_type dtype) {
139139
return get_inv_map().at(dtype);
140140
}
141141

142-
} // namespace oneapi::dal::python
142+
} // namespace oneapi::dal::python::sycl_usm

0 commit comments

Comments
 (0)