Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions onedal/datatypes/numpy/data_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,10 @@ dal::table convert_to_table(py::object inp_obj, py::object queue, bool recursed)
return res;
}

static void free_capsule(PyObject *cap) {
template <class T>
void free_capsule(PyObject *cap) {
// TODO: check safe cast
dal::base *stored_array = static_cast<dal::base *>(PyCapsule_GetPointer(cap, NULL));
dal::array<T> *stored_array = static_cast<dal::array<T> *>(PyCapsule_GetPointer(cap, NULL));
if (stored_array) {
delete stored_array;
}
Expand Down Expand Up @@ -304,7 +305,7 @@ static PyObject *convert_to_numpy_impl(
throw std::invalid_argument("Conversion to numpy array failed");

void *opaque_value = static_cast<void *>(new dal::array<T>(host_array));
PyObject *cap = PyCapsule_New(opaque_value, NULL, free_capsule);
PyObject *cap = PyCapsule_New(opaque_value, NULL, free_capsule<T>);
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), cap);
return obj;
}
Expand Down
4 changes: 2 additions & 2 deletions onedal/datatypes/table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ ONEDAL_PY_INIT_MODULE(table) {
return numpy::convert_to_table(obj, queue);
});

m.def("from_table", [](const dal::table& t) -> py::handle {
m.def("from_table", [](const dal::table& t) -> py::object {
auto* obj_ptr = numpy::convert_to_pyobject(t);
return obj_ptr;
return pybind11::reinterpret_steal<py::object>(obj_ptr);
});
m.def("dlpack_memory_order", &dlpack::dlpack_memory_order);
py::enum_<DLDeviceType>(m, "DLDeviceType")
Expand Down
Loading