Copy framework of incoming nb::ndarray #813
Answered by wjakob
RaulPPelaez asked this question in Q&A
I want to create a C++ extension that returns an ndarray of the same framework as the one it receives, so that I can do something like:

```python
from my_extension import foo
import numpy as np
import torch

a = torch.ones(10)
b = foo(a)
assert type(a) == type(b)

a = np.ones(10)
b = foo(a)
assert type(a) == type(b)
```

With an extension similar to one of the examples:

```cpp
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <cstdio>
#include <vector>

namespace nb = nanobind;

// Allocate a zero-filled float vector and expose it as a NumPy array.
// The capsule frees the vector once Python no longer needs the array.
auto create_array(size_t N = 0) {
    std::vector<float> *v = new std::vector<float>(N, 0);
    nb::capsule deleter(v, [](void *v) noexcept {
        delete static_cast<std::vector<float> *>(v);
    });
    return nb::ndarray<nb::numpy, float>(v->data(), {N},
                                         std::move(deleter));
}

auto process_array_generic(nb::ndarray<float> &a) {
    std::printf("Array data pointer : %p\n", (void *) a.data());
    std::printf("Array dimension : %zu\n", a.ndim());
    std::printf("Device ID = %d (cpu=%i, cuda=%i)\n", a.device_id(),
                int(a.device_type() == nb::device::cpu::value),
                int(a.device_type() == nb::device::cuda::value));
    // I do not know how to get the original framework from "a" in order
    // to create a new one, so the result is always a NumPy array.
    auto array = create_array(a.shape(0));
    return array;
}

NB_MODULE(my_extension, m) {
    m.def("foo", &process_array_generic,
          "Process an array compatible with the DLPack protocol",
          nb::arg("arr"));
}
```

AFAICT, the information about the framework is lost when C++ receives the buffer. I could return a DLPack capsule and handle the conversion on the Python side, but I wonder whether there is a "sane" way to achieve this on the C++ side.
Answered by wjakob on Dec 13, 2024
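This is my current approach to identify the framework of an incoming ndarray:

```cpp
void process_array_generic(nb::ndarray<float> &a) {
    // Ask nanobind for an existing Python object associated with `a`.
    nb::object obj = nb::find(a);
    if (!obj.is_valid())
        return;  // no associated Python object was found
    // Use the class name, e.g. "<class 'torch.Tensor'>", as a framework hint.
    // Passing `obj` (not `obj.release()`) avoids leaking a reference.
    std::string tn = nb::str(nb::getattr(obj, "__class__")).c_str();
    bool isnp = tn.find("numpy") != std::string::npos;
    bool istorch = tn.find("torch") != std::string::npos;
    bool iscupy = tn.find("cupy") != std::string::npos;
    bool isjax = tn.find("jax") != std::string::npos;
}
```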
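The substring checks above match a framework name anywhere in the class string. A slightly stricter variant compares only the top-level package of the array type's module. It is sketched here as plain C++ so it can be checked without Python, assuming the module string was already fetched on the nanobind side, e.g. with `nb::getattr(nb::getattr(obj, "__class__"), "__module__")` followed by `nb::cast<std::string>(...)`. The names `Framework` and `classify_module` are hypothetical, and the module strings mentioned in the comments are examples that may differ between library versions.

```cpp
#include <string>
#include <string_view>

// Hypothetical tag for the array frameworks mentioned in this thread.
enum class Framework { Unknown, NumPy, PyTorch, CuPy, JAX };

// Classify an array by the module its Python type lives in, e.g. "numpy",
// "torch", "cupy._core.core" or "jaxlib.xla_extension". Only the top-level
// package is compared, so "numpy" cannot match inside a longer package name.
Framework classify_module(std::string_view module) {
    // Keep the part before the first '.', or the whole string if there is none.
    std::string_view top = module.substr(0, module.find('.'));
    if (top == "numpy") return Framework::NumPy;
    if (top == "torch") return Framework::PyTorch;
    if (top == "cupy") return Framework::CuPy;
    if (top == "jax" || top == "jaxlib") return Framework::JAX;
    return Framework::Unknown;
}
```

For example, `classify_module("torch")` yields `Framework::PyTorch`, while a package that merely contains "numpy" in its name is reported as `Framework::Unknown`.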
wjakob replied:

Nanobind doesn't support this kind of advanced use case out of the box. There are currently no type casters that preserve metadata (such as the source framework) when going from an argument to a return value.
My suggestion would be to look into writing your own type caster that encapsulates this logic, so that you don't have to write it over and over again. If you never mix frameworks in a single function call, it might be easiest to store the needed framework type in a global or thread-local (TLS) variable.
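The global/TLS suggestion can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not nanobind API: `Framework`, `current_framework`, `FrameworkScope` and `output_framework_name` are hypothetical names, and the nanobind-specific construction of the output array is replaced by returning the framework's name so that the sketch compiles on its own. In a real extension, the bound function would open a `FrameworkScope` after detecting the input's framework, and the return path would switch over `current_framework` to build, e.g., an `nb::ndarray<nb::pytorch, float>` instead of an `nb::ndarray<nb::numpy, float>`.

```cpp
#include <string>

// Hypothetical framework tag; in an extension it would be derived from the
// incoming array before any output is created.
enum class Framework { NumPy, PyTorch, JAX };

// One slot per thread: only valid if a single call never mixes frameworks.
thread_local Framework current_framework = Framework::NumPy;

// RAII guard: records the caller's framework for the duration of one call
// and restores the previous value on exit (also when an exception unwinds).
class FrameworkScope {
public:
    explicit FrameworkScope(Framework fw) : saved_(current_framework) {
        current_framework = fw;
    }
    ~FrameworkScope() { current_framework = saved_; }
    FrameworkScope(const FrameworkScope &) = delete;
    FrameworkScope &operator=(const FrameworkScope &) = delete;

private:
    Framework saved_;
};

// Stand-in for the return path. In an extension, each branch would construct
// the ndarray with the matching framework tag instead of returning a name.
std::string output_framework_name() {
    switch (current_framework) {
        case Framework::NumPy:   return "numpy";
        case Framework::PyTorch: return "pytorch";
        case Framework::JAX:     return "jax";
    }
    return "unknown";
}
```

Restoring the saved value in the destructor keeps the thread-local slot consistent when calls nest, so an inner call that handles a different framework does not clobber the outer one.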