Skip to content

Commit 1039951

Browse files
authored
Add versioning for oneDAL's csr_table class
Use dal::detail::csr_table for oneDAL 2023.1 and older; use dal::csr_table starting from oneDAL 2023.1.1.
1 parent df270a3 commit 1039951

File tree

2 files changed

+117
-26
lines changed

2 files changed

+117
-26
lines changed

onedal/datatypes/data_conversion.cpp

Lines changed: 103 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,26 @@
2020
#include <string>
2121

2222
#include "oneapi/dal/table/homogen.hpp"
23-
#include "oneapi/dal/table/detail/csr.hpp"
2423
#include "oneapi/dal/table/detail/homogen_utils.hpp"
2524

2625
#include "onedal/datatypes/data_conversion.hpp"
2726
#include "onedal/datatypes/numpy_helpers.hpp"
27+
#include "onedal/version.hpp"
28+
29+
#if ONEDAL_VERSION <= 20230100
30+
#include "oneapi/dal/table/detail/csr.hpp"
31+
#else
32+
#include "oneapi/dal/table/csr.hpp"
33+
#endif
2834

2935
namespace oneapi::dal::python {
3036

37+
#if ONEDAL_VERSION <= 20230100
38+
typedef oneapi::dal::detail::csr_table csr_table_t;
39+
#else
40+
typedef oneapi::dal::csr_table csr_table_t;
41+
#endif
42+
3143
template <typename T>
3244
static dal::array<T> transfer_to_host(const dal::array<T>& array) {
3345
#ifdef ONEDAL_DATA_PARALLEL
@@ -82,11 +94,11 @@ inline dal::homogen_table convert_to_homogen_impl(PyArrayObject *np_data) {
8294
}
8395

8496
template <typename T>
85-
inline dal::detail::csr_table convert_to_csr_impl(PyObject *py_data,
86-
PyObject *py_column_indices,
87-
PyObject *py_row_indices,
88-
std::int64_t row_count,
89-
std::int64_t column_count) {
97+
inline csr_table_t convert_to_csr_impl(PyObject* py_data,
98+
PyObject* py_column_indices,
99+
PyObject* py_row_indices,
100+
std::int64_t row_count,
101+
std::int64_t column_count) {
90102
PyArrayObject *np_data = reinterpret_cast<PyArrayObject *>(py_data);
91103
PyArrayObject *np_column_indices = reinterpret_cast<PyArrayObject *>(py_column_indices);
92104
PyArrayObject *np_row_indices = reinterpret_cast<PyArrayObject *>(py_row_indices);
@@ -115,12 +127,18 @@ inline dal::detail::csr_table convert_to_csr_impl(PyObject *py_data,
115127
const T *data_pointer = static_cast<T *>(array_data(np_data));
116128
const std::int64_t data_count = static_cast<std::int64_t>(array_size(np_data, 0));
117129

118-
auto res_table = dal::detail::csr_table(
119-
dal::array<T>(data_pointer, data_count, [np_data](const T* data) { Py_DECREF(np_data); }),
120-
column_indices_one_based,
121-
row_indices_one_based,
122-
row_count,
123-
column_count);
130+
auto res_table = csr_table_t(dal::array<T>(data_pointer,
131+
data_count,
132+
[np_data](const T*) {
133+
Py_DECREF(np_data);
134+
}),
135+
column_indices_one_based,
136+
row_indices_one_based,
137+
#if ONEDAL_VERSION <= 20230100
138+
// row_count parameter is present in csr_table's constructor only in older versions of oneDAL
139+
row_count,
140+
#endif
141+
column_count);
124142

125143
// we need to increment the ref-count as we use the input array in-place
126144
Py_INCREF(np_data);
@@ -232,46 +250,107 @@ static PyObject *convert_to_numpy_impl(const dal::array<T> &array,
232250
return obj;
233251
}
234252

253+
#if ONEDAL_VERSION <= 20230100
254+
255+
// dal::detail::csr_table class is valid
256+
// only one-based indeices are supported
235257
template <int NpType, typename T>
236-
static PyObject *convert_to_py_from_csr_impl(const detail::csr_table &table) {
237-
PyObject *result = PyTuple_New(3);
258+
static PyObject* convert_to_py_from_csr_impl(const detail::csr_table& table) {
259+
PyObject* result = PyTuple_New(3);
238260
const std::int64_t rows_indices_count = table.get_row_count() + 1;
239261

240-
const std::int64_t *row_indices_one_based = table.get_row_indices();
241-
std::uint64_t *row_indices_zero_based_data =
262+
const std::int64_t* row_indices_one_based = table.get_row_indices();
263+
std::uint64_t* row_indices_zero_based_data =
242264
detail::host_allocator<std::uint64_t>().allocate(rows_indices_count);
243265
for (std::int64_t i = 0; i < rows_indices_count; ++i)
244266
row_indices_zero_based_data[i] = row_indices_one_based[i] - 1;
245267

246268
auto row_indices_zero_based_array =
247269
dal::array<std::uint64_t>::wrap(row_indices_zero_based_data, rows_indices_count);
248-
PyObject *py_row =
270+
PyObject* py_row =
249271
convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(row_indices_zero_based_array,
250272
rows_indices_count);
251273
PyTuple_SetItem(result, 2, py_row);
252274

253275
const std::int64_t non_zero_count = row_indices_zero_based_data[rows_indices_count - 1];
254-
const T *data = reinterpret_cast<const T *>(table.get_data());
276+
const T* data = reinterpret_cast<const T*>(table.get_data());
255277
auto data_array = dal::array<T>::wrap(data, non_zero_count);
256278

257-
PyObject *py_data = convert_to_numpy_impl<NpType, T>(data_array, non_zero_count);
279+
PyObject* py_data = convert_to_numpy_impl<NpType, T>(data_array, non_zero_count);
258280
PyTuple_SetItem(result, 0, py_data);
259281

260-
const std::int64_t *column_indices_one_based = table.get_column_indices();
261-
std::uint64_t *column_indices_zero_based_data =
282+
const std::int64_t* column_indices_one_based = table.get_column_indices();
283+
std::uint64_t* column_indices_zero_based_data =
262284
detail::host_allocator<std::uint64_t>().allocate(non_zero_count);
263285
for (std::int64_t i = 0; i < non_zero_count; ++i)
264286
column_indices_zero_based_data[i] = column_indices_one_based[i] - 1;
265287

266288
auto column_indices_zero_based_array =
267289
dal::array<std::uint64_t>::wrap(column_indices_zero_based_data, non_zero_count);
268-
PyObject *py_col =
290+
PyObject* py_col =
269291
convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(column_indices_zero_based_array,
270292
non_zero_count);
271293
PyTuple_SetItem(result, 1, py_col);
272294
return result;
273295
}
274296

297+
#else // ONEDAL_VERSION > 20230100
298+
299+
// dal::csr_table class is valid
300+
// zero- and one-based indeices are supported
301+
template <int NpType, typename T>
302+
static PyObject* convert_to_py_from_csr_impl(const csr_table& table) {
303+
PyObject* result = PyTuple_New(3);
304+
const std::int64_t rows_indices_count = table.get_row_count() + 1;
305+
const std::int64_t non_zero_count = table.get_non_zero_count();
306+
const std::int64_t* row_offsets = table.get_row_offsets();
307+
const std::int64_t* column_indices = table.get_column_indices();
308+
309+
std::uint64_t* column_indices_zero_based_data = nullptr;
310+
std::uint64_t* row_offsets_zero_based_data = nullptr;
311+
312+
if (table.get_indexing() == sparse_indexing::zero_based) {
313+
column_indices_zero_based_data =
314+
const_cast<std::uint64_t*>(reinterpret_cast<const std::uint64_t*>(column_indices));
315+
row_offsets_zero_based_data =
316+
const_cast<std::uint64_t*>(reinterpret_cast<const std::uint64_t*>(row_offsets));
317+
}
318+
else { // table.get_indexing() == sparse_indexing::one_based
319+
column_indices_zero_based_data =
320+
detail::host_allocator<std::uint64_t>().allocate(non_zero_count);
321+
row_offsets_zero_based_data =
322+
detail::host_allocator<std::uint64_t>().allocate(rows_indices_count);
323+
324+
for (std::int64_t i = 0; i < non_zero_count; ++i)
325+
column_indices_zero_based_data[i] = column_indices[i] - 1;
326+
327+
for (std::int64_t i = 0; i < rows_indices_count; ++i)
328+
row_offsets_zero_based_data[i] = row_offsets[i] - 1;
329+
}
330+
331+
const T* data = table.get_data<T>();
332+
auto data_array = dal::array<T>::wrap(data, non_zero_count);
333+
334+
PyObject* py_data = convert_to_numpy_impl<NpType, T>(data_array, non_zero_count);
335+
PyTuple_SetItem(result, 0, py_data);
336+
337+
auto column_indices_zero_based_array =
338+
dal::array<std::uint64_t>::wrap(column_indices_zero_based_data, non_zero_count);
339+
PyObject* py_col =
340+
convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(column_indices_zero_based_array,
341+
non_zero_count);
342+
PyTuple_SetItem(result, 1, py_col);
343+
auto row_indices_zero_based_array =
344+
dal::array<std::uint64_t>::wrap(row_offsets_zero_based_data, rows_indices_count);
345+
PyObject* py_row =
346+
convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(row_indices_zero_based_array,
347+
rows_indices_count);
348+
PyTuple_SetItem(result, 2, py_row);
349+
return result;
350+
}
351+
352+
#endif // ONEDAL_VERSION <= 20230100
353+
275354
PyObject *convert_to_pyobject(const dal::table &input) {
276355
PyObject *res = nullptr;
277356
if (!input.has_data()) {
@@ -301,8 +380,8 @@ PyObject *convert_to_pyobject(const dal::table &input) {
301380
"Output oneDAL table doesn't have row major format for homogen table");
302381
}
303382
}
304-
else if (input.get_kind() == dal::detail::csr_table::kind()) {
305-
const auto &csr_input = static_cast<const detail::csr_table &>(input);
383+
else if (input.get_kind() == csr_table_t::kind()) {
384+
const auto &csr_input = static_cast<const csr_table_t &>(input);
306385
const dal::data_type dtype = csr_input.get_metadata().get_data_type(0);
307386
#define MAKE_PY_FROM_CSR(NpType, T) \
308387
{ res = convert_to_py_from_csr_impl<NpType, T>(csr_input); }

onedal/datatypes/table.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,31 @@
1515
*******************************************************************************/
1616

1717
#include "oneapi/dal/table/homogen.hpp"
18-
#include "oneapi/dal/table/detail/csr.hpp"
1918

2019
#ifdef ONEDAL_DPCTL_INTEGRATION
2120
#include "onedal/datatypes/data_conversion_dpctl.hpp"
2221
#endif // ONEDAL_DPCTL_INTEGRATION
2322

2423
#include "onedal/datatypes/data_conversion.hpp"
2524
#include "onedal/common/pybind11_helpers.hpp"
25+
#include "onedal/version.hpp"
26+
27+
#if ONEDAL_VERSION <= 20230100
28+
#include "oneapi/dal/table/detail/csr.hpp"
29+
#else
30+
#include "oneapi/dal/table/csr.hpp"
31+
#endif
2632

2733
namespace py = pybind11;
2834

2935
namespace oneapi::dal::python {
3036

37+
#if ONEDAL_VERSION <= 20230100
38+
typedef oneapi::dal::detail::csr_table csr_table_t;
39+
#else
40+
typedef oneapi::dal::csr_table csr_table_t;
41+
#endif
42+
3143
static void* init_numpy() {
3244
import_array();
3345
return nullptr;
@@ -48,7 +60,7 @@ ONEDAL_PY_INIT_MODULE(table) {
4860
if (t.get_kind() == homogen_table::kind()) {
4961
return "homogen";
5062
}
51-
if (t.get_kind() == detail::csr_table::kind()) {
63+
if (t.get_kind() == csr_table_t::kind()) {
5264
return "csr";
5365
}
5466
return "unknown";

0 commit comments

Comments
 (0)