Skip to content

Commit ca28499

Browse files
author
KulikovNikita
authored
ENH: DPCTL native support (#1200)
1 parent 1674bb9 commit ca28499

File tree

7 files changed

+415
-45
lines changed

7 files changed

+415
-45
lines changed

onedal/datatypes/_data_conversion.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,18 @@
1616

1717
import warnings
1818
import numpy as np
19+
20+
from onedal import _is_dpc_backend
1921
from onedal import _backend
2022
from daal4py.sklearn._utils import make2d
2123

24+
# Detect an optional dpctl installation that is recent enough (>= 0.14).
try:
    import re

    import dpctl
    import dpctl.tensor as dpt

    # Compare (major, minor) numerically. A plain string comparison is
    # lexicographic, so '0.9' >= '0.14' would wrongly evaluate to True.
    _dpctl_version_match = re.match(r'(\d+)\.(\d+)', dpctl.__version__)
    dpctl_available = (
        _dpctl_version_match is not None
        and tuple(int(part) for part in _dpctl_version_match.groups()) >= (0, 14)
    )
except ImportError:
    dpctl_available = False
30+
2231

2332
def _apply_and_pass(func, *args):
2433
if len(args) == 1:
@@ -29,8 +38,14 @@ def _apply_and_pass(func, *args):
2938
def from_table(*args):
    """Run ``_backend.from_table`` over ``args`` through ``_apply_and_pass``."""
    backend_converter = _backend.from_table
    return _apply_and_pass(backend_converter, *args)
3140

41+
# TODO:
42+
# refactoring.
43+
3244

3345
def convert_one_to_table(arg):
    """Convert a single object into a backend table.

    A ``dpctl.tensor.usm_ndarray`` (only checked when dpctl is available) is
    handed to ``_backend.dpctl_to_table`` as is. Any other input is first
    passed through ``make2d`` and then converted with ``_backend.to_table``.
    """
    # `and` short-circuits, so `dpt` is never touched when dpctl is missing.
    is_usm_ndarray = dpctl_available and isinstance(arg, dpt.usm_ndarray)
    if is_usm_ndarray:
        return _backend.dpctl_to_table(arg)
    return _backend.to_table(make2d(arg))
3651

@@ -39,8 +54,6 @@ def to_table(*args):
3954
return _apply_and_pass(convert_one_to_table, *args)
4055

4156

42-
from onedal import _is_dpc_backend
43-
4457
if _is_dpc_backend:
4558
from ..common._policy import _HostInteropPolicy
4659

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#ifdef ONEDAL_DPCTL_INTEGRATION
18+
#define NO_IMPORT_ARRAY
19+
20+
#include <stdexcept>
21+
#include <utility>
22+
#include <string>
23+
24+
#include "oneapi/dal/table/homogen.hpp"
25+
#include "oneapi/dal/table/detail/csr.hpp"
26+
#include "oneapi/dal/table/detail/homogen_utils.hpp"
27+
28+
#include "onedal/datatypes/data_conversion_dpctl.hpp"
29+
#include "onedal/datatypes/numpy_helpers.hpp"
30+
31+
#include "dpctl4pybind11.hpp"
32+
33+
namespace oneapi::dal::python {
34+
35+
/// Throws `std::invalid_argument` describing a failed conversion from a dptensor.
///
/// @param clarification  Text appended verbatim to the base message
///                       "Unable to convert from dptensor"; a null pointer
///                       is treated as an empty clarification.
/// @throws std::invalid_argument always; the function never returns.
[[noreturn]] void report_problem_from_dptensor(const char* clarification) {
    std::string message{ "Unable to convert from dptensor" };
    // Constructing std::string from a null pointer is undefined behaviour.
    if (clarification != nullptr) {
        message += clarification;
    }
    throw std::invalid_argument{ message };
}
42+
43+
/// Returns the number of dimensions of `tensor` as `std::int64_t`.
///
/// Any rank other than 1 or 2 is reported through
/// `report_problem_from_dptensor`.
std::int64_t get_and_check_dptensor_ndim(const dpctl::tensor::usm_ndarray& tensor) {
    const auto rank = dal::detail::integral_cast<std::int64_t>(tensor.get_ndim());

    const bool is_supported_rank = (rank == 1) || (rank == 2);
    if (!is_supported_rank) {
        report_problem_from_dptensor(": only 1D & 2D tensors are allowed");
    }

    return rank;
}
51+
52+
/// Returns the `(row_count, column_count)` pair describing `tensor`.
///
/// A 1D tensor is reported as a single column whose row count is the
/// extent of axis 0; a 2D tensor reports the extents of axes 0 and 1.
auto get_dptensor_shape(const dpctl::tensor::usm_ndarray& tensor) {
    const bool is_one_dimensional = (get_and_check_dptensor_ndim(tensor) == 1l);

    // Axis 0 always provides the row count.
    const std::int64_t n_rows = dal::detail::integral_cast<std::int64_t>(tensor.get_shape(0));

    // Axis 1 is queried only when the tensor actually has it.
    const std::int64_t n_cols =
        is_one_dimensional ? std::int64_t{ 1 }
                           : dal::detail::integral_cast<std::int64_t>(tensor.get_shape(1));

    return std::make_pair(n_rows, n_cols);
}
66+
67+
/// Maps the memory order of `tensor` onto a oneDAL data layout.
///
/// A 1D tensor is always described as row-major. A 2D tensor is row-major
/// when it is C-contiguous and column-major otherwise (i.e. F-contiguous).
///
/// A tensor that is neither C- nor F-contiguous is rejected through
/// `report_problem_from_dptensor`: the returned layout carries no stride
/// information, so such data would otherwise be silently described with
/// a layout that does not match its memory.
auto get_dptensor_layout(const dpctl::tensor::usm_ndarray& tensor) {
    const auto ndim = get_and_check_dptensor_ndim(tensor);
    const bool is_c_cont = tensor.is_c_contiguous();
    const bool is_f_cont = tensor.is_f_contiguous();

    if (!is_c_cont && !is_f_cont) {
        report_problem_from_dptensor(": only C-contiguous or F-contiguous tensors are supported");
    }

    if (ndim == 1l) {
        return dal::data_layout::row_major;
    }
    return is_c_cont ? dal::data_layout::row_major : dal::data_layout::column_major;
}
83+
84+
template <typename Type>
85+
dal::table convert_to_homogen_impl(py::object obj, dpctl::tensor::usm_ndarray& tensor) {
86+
const dpctl::tensor::usm_ndarray* const ptr = &tensor;
87+
const auto deleter = [obj](const Type*) {
88+
obj.dec_ref();
89+
};
90+
const auto [r_count, c_count] = get_dptensor_shape(tensor);
91+
const auto layout = get_dptensor_layout(tensor);
92+
const auto* data = tensor.get_data<Type>();
93+
const auto queue = tensor.get_queue();
94+
95+
auto res = dal::homogen_table(queue,
96+
data,
97+
r_count,
98+
c_count, //
99+
deleter,
100+
std::vector<sycl::event>{},
101+
layout);
102+
103+
obj.inc_ref();
104+
105+
return res;
106+
}
107+
108+
/// Converts a Python object holding a dpctl `usm_ndarray` into a `dal::table`.
///
/// The object is cast to `dpctl::tensor::usm_ndarray`, and its type number
/// is dispatched through `SET_NPY_FEATURE` to the matching
/// `convert_to_homogen_impl` instantiation. The fallback argument of the
/// dispatch reports ": unknown data type".
dal::table convert_from_dptensor(py::object obj) {
    auto tensor = pybind11::cast<dpctl::tensor::usm_ndarray>(obj);
    const auto typenum = tensor.get_typenum();

    dal::table result{};

// Local helper for SET_NPY_FEATURE: builds the table for one element type.
#define ASSIGN_HOMOGEN_TABLE(CType) \
    result = convert_to_homogen_impl<CType>(obj, tensor);

    SET_NPY_FEATURE(typenum,
                    ASSIGN_HOMOGEN_TABLE, //
                    report_problem_from_dptensor(": unknown data type"));

#undef ASSIGN_HOMOGEN_TABLE

    return result;
}
126+
127+
/// Throws `std::runtime_error` describing a failed conversion to a dptensor.
///
/// @param clarification  Text appended verbatim to the base message
///                       "Unable to convert to dptensor"; a null pointer
///                       is treated as an empty clarification.
/// @throws std::runtime_error always; the function never returns.
[[noreturn]] void report_problem_to_dptensor(const char* clarification) {
    std::string message{ "Unable to convert to dptensor" };
    // Constructing std::string from a null pointer is undefined behaviour.
    if (clarification != nullptr) {
        message += clarification;
    }
    throw std::runtime_error{ message };
}
134+
135+
// TODO:
136+
// return type.
137+
std::string get_npy_typestr(const dal::data_type dtype) {
138+
switch (dtype) {
139+
case dal::data_type::float32: {
140+
return "<f4";
141+
break;
142+
}
143+
case dal::data_type::float64: {
144+
return "<f8";
145+
break;
146+
}
147+
case dal::data_type::int32: {
148+
return "<i4";
149+
break;
150+
}
151+
case dal::data_type::int64: {
152+
return "<i8";
153+
break;
154+
}
155+
default: report_problem_to_dptensor(": unknown data type");
156+
};
157+
}
158+
159+
/// Builds the two-element strides tuple for a table of the given shape.
///
/// Row-major data yields `(column_count, 1)`; any other known layout yields
/// `(1, row_count)`. An unknown layout is reported through
/// `report_problem_to_dptensor`.
py::tuple get_npy_strides(const dal::data_layout& data_layout,
                          npy_intp row_count,
                          npy_intp column_count) {
    if (data_layout == dal::data_layout::unknown) {
        report_problem_to_dptensor(": unknown data layout");
    }

    const bool is_row_major = (data_layout == dal::data_layout::row_major);
    return is_row_major ? py::make_tuple(column_count, 1l) //
                        : py::make_tuple(1l, row_count);
}
174+
175+
/// Builds the `__sycl_usm_array_interface__` dictionary for a homogen table.
///
/// The dictionary carries the keys "data", "shape", "strides", "version",
/// "typestr" and "syclobj".
///
/// Problems are reported through `report_problem_to_dptensor` when the table
/// is not a homogen table, when its original data array has no queue, and
/// (via the helpers) when the data type or data layout is unknown.
///
/// @param input  Table to describe; must be of homogen kind.
/// @return The interface dictionary.
py::dict construct_sua_iface(const dal::table& input) {
    if (input.get_kind() != dal::homogen_table::kind()) {
        report_problem_to_dptensor(": only homogen tables are supported");
    }

    const auto& homogen_input = reinterpret_cast<const dal::homogen_table&>(input);
    const dal::data_type dtype = homogen_input.get_metadata().get_data_type(0);
    const dal::data_layout data_layout = homogen_input.get_data_layout();

    const npy_intp row_count =
        dal::detail::integral_cast<npy_intp>(homogen_input.get_row_count());
    const npy_intp column_count =
        dal::detail::integral_cast<npy_intp>(homogen_input.get_column_count());

    auto bytes_array = dal::detail::get_original_data(homogen_input);

    // Query the queue once and reuse the result.
    const auto optional_queue = bytes_array.get_queue();
    if (!optional_queue.has_value()) {
        report_problem_to_dptensor(": table has no queue");
    }
    auto queue = optional_queue.value();

    const bool is_mutable = bytes_array.has_mutable_data();

    // The pointer is exported as a Python integer.
    static_assert(sizeof(std::size_t) == sizeof(void*));
    py::list data_entry(2);
    data_entry[0] = is_mutable ? reinterpret_cast<std::size_t>(bytes_array.get_mutable_data())
                               : reinterpret_cast<std::size_t>(bytes_array.get_data());
    // The second "data" entry is a read-only flag (True means read-only),
    // so it is the negation of the array's mutability.
    data_entry[1] = !is_mutable;

    py::dict iface;
    iface["data"] = data_entry;
    iface["shape"] = py::make_tuple(row_count, column_count);
    iface["strides"] = get_npy_strides(data_layout, row_count, column_count);
    // dpctl supports only version 1.
    iface["version"] = 1;
    iface["typestr"] = get_npy_typestr(dtype);
    iface["syclobj"] = py::cast(queue);

    return iface;
}
217+
218+
// We are using `__sycl_usm_array_interface__` attribute for constructing
219+
// dpctl tensor on python level.
220+
void define_sycl_usm_array_property(py::class_<dal::table>& table_obj) {
221+
table_obj.def_property_readonly("__sycl_usm_array_interface__", &construct_sua_iface);
222+
}
223+
224+
} // namespace oneapi::dal::python
225+
226+
#endif // ONEDAL_DPCTL_INTEGRATION
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#pragma once
18+
19+
#define PY_ARRAY_UNIQUE_SYMBOL ONEDAL_PY_ARRAY_API
20+
21+
#include <pybind11/pybind11.h>
22+
#include <numpy/arrayobject.h>
23+
24+
#include "oneapi/dal/table/common.hpp"
25+
26+
namespace oneapi::dal::python {
27+
28+
namespace py = pybind11;
29+
30+
dal::table convert_from_dptensor(py::object obj);
31+
py::dict construct_sua_iface(const dal::table& input);
32+
33+
void define_sycl_usm_array_property(py::class_<dal::table>& t);
34+
35+
} // namespace oneapi::dal::python

onedal/datatypes/table.cpp

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#include "oneapi/dal/table/homogen.hpp"
1818
#include "oneapi/dal/table/detail/csr.hpp"
1919

20+
#ifdef ONEDAL_DPCTL_INTEGRATION
21+
#include "onedal/datatypes/data_conversion_dpctl.hpp"
22+
#endif // ONEDAL_DPCTL_INTEGRATION
23+
2024
#include "onedal/datatypes/data_conversion.hpp"
2125
#include "onedal/common/pybind11_helpers.hpp"
2226

@@ -32,23 +36,27 @@ static void* init_numpy() {
3236
ONEDAL_PY_INIT_MODULE(table) {
3337
init_numpy();
3438

35-
py::class_<table>(m, "table")
36-
.def(py::init())
37-
.def_property_readonly("has_data", &table::has_data)
38-
.def_property_readonly("column_count", &table::get_column_count)
39-
.def_property_readonly("row_count", &table::get_row_count)
40-
.def_property_readonly("kind", [](const table& t) {
41-
if (t.get_kind() == 0) { // TODO: expose empty table kind
42-
return "empty";
43-
}
44-
if (t.get_kind() == homogen_table::kind()) {
45-
return "homogen";
46-
}
47-
if (t.get_kind() == detail::csr_table::kind()) {
48-
return "csr";
49-
}
50-
return "unknown";
51-
});
39+
py::class_<table> table_obj(m, "table");
40+
table_obj.def(py::init());
41+
table_obj.def_property_readonly("has_data", &table::has_data);
42+
table_obj.def_property_readonly("column_count", &table::get_column_count);
43+
table_obj.def_property_readonly("row_count", &table::get_row_count);
44+
table_obj.def_property_readonly("kind", [](const table& t) {
45+
if (t.get_kind() == 0) { // TODO: expose empty table kind
46+
return "empty";
47+
}
48+
if (t.get_kind() == homogen_table::kind()) {
49+
return "homogen";
50+
}
51+
if (t.get_kind() == detail::csr_table::kind()) {
52+
return "csr";
53+
}
54+
return "unknown";
55+
});
56+
57+
#ifdef ONEDAL_DPCTL_INTEGRATION
58+
define_sycl_usm_array_property(table_obj);
59+
#endif // ONEDAL_DPCTL_INTEGRATION
5260

5361
m.def("to_table", [](py::object obj) {
5462
auto* obj_ptr = obj.ptr();
@@ -59,6 +67,13 @@ ONEDAL_PY_INIT_MODULE(table) {
5967
auto* obj_ptr = convert_to_pyobject(t);
6068
return obj_ptr;
6169
});
70+
71+
#ifdef ONEDAL_DPCTL_INTEGRATION
72+
m.def("dpctl_to_table", [](py::object obj) {
73+
return convert_from_dptensor(obj);
74+
});
75+
76+
#endif // ONEDAL_DPCTL_INTEGRATION
6277
}
6378

6479
} // namespace oneapi::dal::python

0 commit comments

Comments
 (0)