|
20 | 20 | #include <string> |
21 | 21 |
|
22 | 22 | #include "oneapi/dal/table/homogen.hpp" |
23 | | -#include "oneapi/dal/table/detail/csr.hpp" |
24 | 23 | #include "oneapi/dal/table/detail/homogen_utils.hpp" |
25 | 24 |
|
26 | 25 | #include "onedal/datatypes/data_conversion.hpp" |
27 | 26 | #include "onedal/datatypes/numpy_helpers.hpp" |
| 27 | +#include "onedal/version.hpp" |
| 28 | + |
| 29 | +#if ONEDAL_VERSION <= 20230100 |
| 30 | + #include "oneapi/dal/table/detail/csr.hpp" |
| 31 | +#else |
| 32 | + #include "oneapi/dal/table/csr.hpp" |
| 33 | +#endif |
28 | 34 |
|
29 | 35 | namespace oneapi::dal::python { |
30 | 36 |
|
| 37 | +#if ONEDAL_VERSION <= 20230100 |
| 38 | +typedef oneapi::dal::detail::csr_table csr_table_t; |
| 39 | +#else |
| 40 | +typedef oneapi::dal::csr_table csr_table_t; |
| 41 | +#endif |
| 42 | + |
31 | 43 | template <typename T> |
32 | 44 | static dal::array<T> transfer_to_host(const dal::array<T>& array) { |
33 | 45 | #ifdef ONEDAL_DATA_PARALLEL |
@@ -82,11 +94,11 @@ inline dal::homogen_table convert_to_homogen_impl(PyArrayObject *np_data) { |
82 | 94 | } |
83 | 95 |
|
84 | 96 | template <typename T> |
85 | | -inline dal::detail::csr_table convert_to_csr_impl(PyObject *py_data, |
86 | | - PyObject *py_column_indices, |
87 | | - PyObject *py_row_indices, |
88 | | - std::int64_t row_count, |
89 | | - std::int64_t column_count) { |
| 97 | +inline csr_table_t convert_to_csr_impl(PyObject* py_data, |
| 98 | + PyObject* py_column_indices, |
| 99 | + PyObject* py_row_indices, |
| 100 | + std::int64_t row_count, |
| 101 | + std::int64_t column_count) { |
90 | 102 | PyArrayObject *np_data = reinterpret_cast<PyArrayObject *>(py_data); |
91 | 103 | PyArrayObject *np_column_indices = reinterpret_cast<PyArrayObject *>(py_column_indices); |
92 | 104 | PyArrayObject *np_row_indices = reinterpret_cast<PyArrayObject *>(py_row_indices); |
@@ -115,12 +127,18 @@ inline dal::detail::csr_table convert_to_csr_impl(PyObject *py_data, |
115 | 127 | const T *data_pointer = static_cast<T *>(array_data(np_data)); |
116 | 128 | const std::int64_t data_count = static_cast<std::int64_t>(array_size(np_data, 0)); |
117 | 129 |
|
118 | | - auto res_table = dal::detail::csr_table( |
119 | | - dal::array<T>(data_pointer, data_count, [np_data](const T* data) { Py_DECREF(np_data); }), |
120 | | - column_indices_one_based, |
121 | | - row_indices_one_based, |
122 | | - row_count, |
123 | | - column_count); |
| 130 | + auto res_table = csr_table_t(dal::array<T>(data_pointer, |
| 131 | + data_count, |
| 132 | + [np_data](const T*) { |
| 133 | + Py_DECREF(np_data); |
| 134 | + }), |
| 135 | + column_indices_one_based, |
| 136 | + row_indices_one_based, |
| 137 | +#if ONEDAL_VERSION <= 20230100 |
| 138 | +// row_count parameter is present in csr_table's constructor only in older versions of oneDAL |
| 139 | + row_count, |
| 140 | +#endif |
| 141 | + column_count); |
124 | 142 |
|
125 | 143 | // we need to increment the ref-count as we use the input array in-place |
126 | 144 | Py_INCREF(np_data); |
@@ -232,46 +250,107 @@ static PyObject *convert_to_numpy_impl(const dal::array<T> &array, |
232 | 250 | return obj; |
233 | 251 | } |
234 | 252 |
|
| 253 | +#if ONEDAL_VERSION <= 20230100 |
| 254 | + |
| 255 | +// dal::detail::csr_table class is valid |
| 256 | +// only one-based indices are supported |
235 | 257 | template <int NpType, typename T> |
236 | | -static PyObject *convert_to_py_from_csr_impl(const detail::csr_table &table) { |
237 | | - PyObject *result = PyTuple_New(3); |
| 258 | +static PyObject* convert_to_py_from_csr_impl(const detail::csr_table& table) { |
| 259 | + PyObject* result = PyTuple_New(3); |
238 | 260 | const std::int64_t rows_indices_count = table.get_row_count() + 1; |
239 | 261 |
|
240 | | - const std::int64_t *row_indices_one_based = table.get_row_indices(); |
241 | | - std::uint64_t *row_indices_zero_based_data = |
| 262 | + const std::int64_t* row_indices_one_based = table.get_row_indices(); |
| 263 | + std::uint64_t* row_indices_zero_based_data = |
242 | 264 | detail::host_allocator<std::uint64_t>().allocate(rows_indices_count); |
243 | 265 | for (std::int64_t i = 0; i < rows_indices_count; ++i) |
244 | 266 | row_indices_zero_based_data[i] = row_indices_one_based[i] - 1; |
245 | 267 |
|
246 | 268 | auto row_indices_zero_based_array = |
247 | 269 | dal::array<std::uint64_t>::wrap(row_indices_zero_based_data, rows_indices_count); |
248 | | - PyObject *py_row = |
| 270 | + PyObject* py_row = |
249 | 271 | convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(row_indices_zero_based_array, |
250 | 272 | rows_indices_count); |
251 | 273 | PyTuple_SetItem(result, 2, py_row); |
252 | 274 |
|
253 | 275 | const std::int64_t non_zero_count = row_indices_zero_based_data[rows_indices_count - 1]; |
254 | | - const T *data = reinterpret_cast<const T *>(table.get_data()); |
| 276 | + const T* data = reinterpret_cast<const T*>(table.get_data()); |
255 | 277 | auto data_array = dal::array<T>::wrap(data, non_zero_count); |
256 | 278 |
|
257 | | - PyObject *py_data = convert_to_numpy_impl<NpType, T>(data_array, non_zero_count); |
| 279 | + PyObject* py_data = convert_to_numpy_impl<NpType, T>(data_array, non_zero_count); |
258 | 280 | PyTuple_SetItem(result, 0, py_data); |
259 | 281 |
|
260 | | - const std::int64_t *column_indices_one_based = table.get_column_indices(); |
261 | | - std::uint64_t *column_indices_zero_based_data = |
| 282 | + const std::int64_t* column_indices_one_based = table.get_column_indices(); |
| 283 | + std::uint64_t* column_indices_zero_based_data = |
262 | 284 | detail::host_allocator<std::uint64_t>().allocate(non_zero_count); |
263 | 285 | for (std::int64_t i = 0; i < non_zero_count; ++i) |
264 | 286 | column_indices_zero_based_data[i] = column_indices_one_based[i] - 1; |
265 | 287 |
|
266 | 288 | auto column_indices_zero_based_array = |
267 | 289 | dal::array<std::uint64_t>::wrap(column_indices_zero_based_data, non_zero_count); |
268 | | - PyObject *py_col = |
| 290 | + PyObject* py_col = |
269 | 291 | convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(column_indices_zero_based_array, |
270 | 292 | non_zero_count); |
271 | 293 | PyTuple_SetItem(result, 1, py_col); |
272 | 294 | return result; |
273 | 295 | } |
274 | 296 |
|
| 297 | +#else // ONEDAL_VERSION > 20230100 |
| 298 | + |
| 299 | +// dal::csr_table class is valid |
| 300 | +// zero- and one-based indices are supported |
| 301 | +template <int NpType, typename T> |
| 302 | +static PyObject* convert_to_py_from_csr_impl(const csr_table& table) { |
| 303 | + PyObject* result = PyTuple_New(3); |
| 304 | + const std::int64_t rows_indices_count = table.get_row_count() + 1; |
| 305 | + const std::int64_t non_zero_count = table.get_non_zero_count(); |
| 306 | + const std::int64_t* row_offsets = table.get_row_offsets(); |
| 307 | + const std::int64_t* column_indices = table.get_column_indices(); |
| 308 | + |
| 309 | + std::uint64_t* column_indices_zero_based_data = nullptr; |
| 310 | + std::uint64_t* row_offsets_zero_based_data = nullptr; |
| 311 | + |
| 312 | + if (table.get_indexing() == sparse_indexing::zero_based) { |
| 313 | + column_indices_zero_based_data = |
| 314 | + const_cast<std::uint64_t*>(reinterpret_cast<const std::uint64_t*>(column_indices)); |
| 315 | + row_offsets_zero_based_data = |
| 316 | + const_cast<std::uint64_t*>(reinterpret_cast<const std::uint64_t*>(row_offsets)); |
| 317 | + } |
| 318 | + else { // table.get_indexing() == sparse_indexing::one_based |
| 319 | + column_indices_zero_based_data = |
| 320 | + detail::host_allocator<std::uint64_t>().allocate(non_zero_count); |
| 321 | + row_offsets_zero_based_data = |
| 322 | + detail::host_allocator<std::uint64_t>().allocate(rows_indices_count); |
| 323 | + |
| 324 | + for (std::int64_t i = 0; i < non_zero_count; ++i) |
| 325 | + column_indices_zero_based_data[i] = column_indices[i] - 1; |
| 326 | + |
| 327 | + for (std::int64_t i = 0; i < rows_indices_count; ++i) |
| 328 | + row_offsets_zero_based_data[i] = row_offsets[i] - 1; |
| 329 | + } |
| 330 | + |
| 331 | + const T* data = table.get_data<T>(); |
| 332 | + auto data_array = dal::array<T>::wrap(data, non_zero_count); |
| 333 | + |
| 334 | + PyObject* py_data = convert_to_numpy_impl<NpType, T>(data_array, non_zero_count); |
| 335 | + PyTuple_SetItem(result, 0, py_data); |
| 336 | + |
| 337 | + auto column_indices_zero_based_array = |
| 338 | + dal::array<std::uint64_t>::wrap(column_indices_zero_based_data, non_zero_count); |
| 339 | + PyObject* py_col = |
| 340 | + convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(column_indices_zero_based_array, |
| 341 | + non_zero_count); |
| 342 | + PyTuple_SetItem(result, 1, py_col); |
| 343 | + auto row_indices_zero_based_array = |
| 344 | + dal::array<std::uint64_t>::wrap(row_offsets_zero_based_data, rows_indices_count); |
| 345 | + PyObject* py_row = |
| 346 | + convert_to_numpy_impl<NPY_UINT64, std::uint64_t>(row_indices_zero_based_array, |
| 347 | + rows_indices_count); |
| 348 | + PyTuple_SetItem(result, 2, py_row); |
| 349 | + return result; |
| 350 | +} |
| 351 | + |
| 352 | +#endif // ONEDAL_VERSION <= 20230100 |
| 353 | + |
275 | 354 | PyObject *convert_to_pyobject(const dal::table &input) { |
276 | 355 | PyObject *res = nullptr; |
277 | 356 | if (!input.has_data()) { |
@@ -301,8 +380,8 @@ PyObject *convert_to_pyobject(const dal::table &input) { |
301 | 380 | "Output oneDAL table doesn't have row major format for homogen table"); |
302 | 381 | } |
303 | 382 | } |
304 | | - else if (input.get_kind() == dal::detail::csr_table::kind()) { |
305 | | - const auto &csr_input = static_cast<const detail::csr_table &>(input); |
| 383 | + else if (input.get_kind() == csr_table_t::kind()) { |
| 384 | + const auto &csr_input = static_cast<const csr_table_t &>(input); |
306 | 385 | const dal::data_type dtype = csr_input.get_metadata().get_data_type(0); |
307 | 386 | #define MAKE_PY_FROM_CSR(NpType, T) \ |
308 | 387 | { res = convert_to_py_from_csr_impl<NpType, T>(csr_input); } |
|
0 commit comments