diff --git a/onedal/datatypes/numpy/data_conversion.cpp b/onedal/datatypes/numpy/data_conversion.cpp index 3f308148ce..7f1cb47c5c 100644 --- a/onedal/datatypes/numpy/data_conversion.cpp +++ b/onedal/datatypes/numpy/data_conversion.cpp @@ -269,15 +269,16 @@ dal::table convert_to_table(py::object inp_obj, py::object queue, bool recursed) return res; } -static void free_capsule(PyObject *cap) { +template +void free_capsule(PyObject *cap) { // TODO: check safe cast - dal::base *stored_array = static_cast(PyCapsule_GetPointer(cap, NULL)); + dal::array *stored_array = static_cast *>(PyCapsule_GetPointer(cap, NULL)); if (stored_array) { delete stored_array; } } -template +template static PyObject *convert_to_numpy_impl( const dal::array &array, std::int64_t row_count, @@ -304,7 +305,7 @@ static PyObject *convert_to_numpy_impl( throw std::invalid_argument("Conversion to numpy array failed"); void *opaque_value = static_cast(new dal::array(host_array)); - PyObject *cap = PyCapsule_New(opaque_value, NULL, free_capsule); + PyObject *cap = PyCapsule_New(opaque_value, NULL, free_capsule); PyArray_SetBaseObject(reinterpret_cast(obj), cap); return obj; } @@ -420,18 +421,29 @@ PyObject *convert_to_pyobject(const dal::table &input) { const auto &homogen_input = static_cast(input); const dal::data_type dtype = homogen_input.get_metadata().get_data_type(0); -#define MAKE_NYMPY_FROM_HOMOGEN(NpType) \ - { \ - auto bytes_array = dal::detail::get_original_data(homogen_input); \ - res = convert_to_numpy_impl(bytes_array, \ - homogen_input.get_row_count(), \ - homogen_input.get_column_count(), \ - homogen_input.get_data_layout()); \ +#define MAKE_NUMPY_FROM_HOMOGEN(NpType, T) \ + { \ + auto bytes_array = dal::detail::get_original_data(homogen_input); \ + dal::array typed_array; \ + if (bytes_array.has_mutable_data()) { \ + typed_array.reset(bytes_array, \ + reinterpret_cast(bytes_array.get_mutable_data()), \ + bytes_array.get_count() / sizeof(T)); \ + } \ + else { \ + typed_array.reset(bytes_array, \ + reinterpret_cast(bytes_array.get_data()), \ + bytes_array.get_count() / sizeof(T)); \ + } \ + res = convert_to_numpy_impl(typed_array, \ + homogen_input.get_row_count(), \ + homogen_input.get_column_count(), \ + homogen_input.get_data_layout()); \ } - SET_CTYPE_NPY_FROM_DAL_TYPE(dtype, - MAKE_NYMPY_FROM_HOMOGEN, - throw std::invalid_argument("Unable to convert numpy object")); -#undef MAKE_NYMPY_FROM_HOMOGEN + SET_CTYPES_NPY_FROM_DAL_TYPE(dtype, + MAKE_NUMPY_FROM_HOMOGEN, + throw std::invalid_argument("Unable to convert numpy object")); +#undef MAKE_NUMPY_FROM_HOMOGEN } else if (input.get_kind() == csr_table_t::kind()) { const auto &csr_input = static_cast(input); diff --git a/onedal/datatypes/table.cpp b/onedal/datatypes/table.cpp index 12f51dca81..8824fb85af 100644 --- a/onedal/datatypes/table.cpp +++ b/onedal/datatypes/table.cpp @@ -103,9 +103,9 @@ ONEDAL_PY_INIT_MODULE(table) { return numpy::convert_to_table(obj, queue); }); - m.def("from_table", [](const dal::table& t) -> py::handle { + m.def("from_table", [](const dal::table& t) -> py::object { auto* obj_ptr = numpy::convert_to_pyobject(t); - return obj_ptr; + return py::reinterpret_steal(obj_ptr); }); m.def("dlpack_memory_order", &dlpack::dlpack_memory_order); py::enum_(m, "DLDeviceType")