diff --git a/docs/api_extra.rst b/docs/api_extra.rst index 83b17e99..b0560fed 100644 --- a/docs/api_extra.rst +++ b/docs/api_extra.rst @@ -1108,6 +1108,11 @@ convert into an equivalent representation in one of the following frameworks: Builtin Python ``memoryview`` for CPU-resident data. +.. cpp:class:: array_api + + An object that both implements the buffer protocol and also has the + ``__dlpack__`` and ``__dlpack_device__`` attributes. + Eigen convenience type aliases ------------------------------ diff --git a/docs/changelog.rst b/docs/changelog.rst index 62f6f118..0027ed05 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -22,6 +22,14 @@ Version TBD (not yet released) Clang-based Intel compiler). Continuous integration tests have been added to ensure compatibility with these compilers on an ongoing basis. +- The framework ``nb::array_api`` is now available to return an nd-array from + C++ to Python as an object that supports both the Python buffer protocol as + well as the DLPack methods ``__dlpack__`` and ``__dlpack_device__``. + Nanobind now supports importing and exporting nd-arrays via capsules that + contain the ``DLManagedTensorVersioned`` struct, which has a flag bit + indicating the nd-array is read-only. + (PR `#1175 <https://github.com/wjakob/nanobind/pull/1175>`__). + Version 2.9.2 (Sep 4, 2025) --------------------------- diff --git a/docs/ndarray.rst b/docs/ndarray.rst index e62fcb37..d2f71dd3 100644 --- a/docs/ndarray.rst +++ b/docs/ndarray.rst @@ -275,12 +275,19 @@ desired Python type. - :cpp:class:`nb::tensorflow `: create a ``tensorflow.python.framework.ops.EagerTensor``. - :cpp:class:`nb::jax `: create a ``jaxlib.xla_extension.DeviceArray``. - :cpp:class:`nb::cupy `: create a ``cupy.ndarray``. +- :cpp:class:`nb::memview `: create a Python ``memoryview``. 
+- :cpp:class:`nb::array_api `: create an object that supports the + Python buffer protocol (i.e., is accepted as an argument to ``memoryview()``) + and also has the DLPack attributes ``__dlpack__`` and ``__dlpack_device__`` + (i.e., it is accepted as an argument to a framework's ``from_dlpack()`` + function). - No framework annotation. In this case, nanobind will create a raw Python ``dltensor`` `capsule `__ - representing the `DLPack `__ metadata. + representing the `DLPack `__ metadata of + a ``DLManagedTensor``. This annotation also affects the auto-generated docstring of the function, -which in this case becomes: +which in this example's case becomes: .. code-block:: python @@ -458,6 +465,21 @@ interpreted as follows: - :cpp:enumerator:`rv_policy::move` is unsupported and demoted to :cpp:enumerator:`rv_policy::copy`. +Note that when a copy is returned, the copy is made by the framework, not by +nanobind itself. +For example, ``numpy.array()`` is passed the keyword argument ``copy`` with +value ``True``, or the PyTorch tensor's ``clone()`` method is immediately +called to create the copy. +This design has a couple of advantages. +First, nanobind does not have a build-time dependency on the libraries and +frameworks (NumPy, PyTorch, CUDA, etc.) that would otherwise be necessary +to perform the copy. +Second, frameworks have the opportunity to optimize how the copy is created. +The copy is owned by the framework, so the framework can choose to use a custom +memory allocator, over-align the data, etc. based on the nd-array's size, +the specific CPU, GPU, or memory types detected, etc. + + .. _ndarray-temporaries: Returning temporaries @@ -643,26 +665,92 @@ support inter-framework data exchange, custom array types should implement the - `__dlpack__ `__ and - `__dlpack_device__ `__ -methods. This is easy thanks to the nd-array integration in nanobind. An example is shown below: +methods. 
+These, as well as the buffer protocol, are implemented in the object returned +by nanobind when specifying :cpp:class:`nb::array_api ` as the +framework template parameter. +For example: .. code-block:: cpp - nb::class_(m, "MyArray") - // ... - .def("__dlpack__", [](nb::kwargs kwargs) { - return nb::ndarray<>( /* ... */); - }) - .def("__dlpack_device__", []() { - return std::make_pair(nb::device::cpu::value, 0); - }); + class MyArray { + double* d; + public: + MyArray() { d = new double[5] { 0.0, 1.0, 2.0, 3.0, 4.0 }; } + ~MyArray() { delete[] d; } + double* data() const { return d; } + }; + + nb::class_(m, "MyArray") + .def(nb::init<>()) + .def("array_api", [](const MyArray& self) { + return nb::ndarray(self.data(), {5}); + }, nb::rv_policy::reference_internal); + +which can be used as follows: + +.. code-block:: pycon -Returning a raw :cpp:class:`nb::ndarray ` without framework annotation -will produce a DLPack capsule, which is what the interface expects. + >>> import my_extension + >>> ma = my_extension.MyArray() + >>> aa = ma.array_api() + >>> aa.__dlpack_device__() + (1, 0) + >>> import numpy as np + >>> x = np.from_dlpack(aa) + >>> x + array([0., 1., 2., 3., 4.]) + +The DLPack methods can also be provided for the class itself, by implementing +``__dlpack__()`` as a wrapper function. +For example, by adding the following lines to the binding: + +.. code-block:: cpp + + .def("__dlpack__", [](nb::pointer_and_handle self, + nb::kwargs kwargs) { + using array_api_t = nb::ndarray; + nb::object aa = nb::cast(array_api_t(self.p->data(), {5}), + nb::rv_policy::reference_internal, + self.h); + nb::object max = kwargs.get("max_version", nb::none()); + return aa.attr("__dlpack__")(nb::arg("max_version") = max); + }) + .def("__dlpack_device__", [](nb::handle /*self*/) { + return std::make_pair(nb::device::cpu::value, 0); + }) + +the class can be used as follows: + +.. 
code-block:: pycon + + >>> import my_extension + >>> ma = my_extension.MyArray() + >>> ma.__dlpack_device__() + (1, 0) + >>> import numpy as np + >>> y = np.from_dlpack(ma) + >>> y + array([0., 1., 2., 3., 4.]) + + +The ``kwargs`` argument in the implementation of ``__dlpack__`` above can be +used to support additional parameters (e.g., to allow the caller to request a +copy). Please see the DLPack documentation for details. + +The caller may or may not supply the keyword argument ``max_version``. +If it is not supplied or has the value ``None``, nanobind will return an +unversioned ``DLManagedTensor`` in a capsule named ``dltensor``. +If its value is a tuple of integers ``(major_version, minor_version)`` and the +major version is at least 1, nanobind will return a ``DLManagedTensorVersioned`` +in a capsule named ``dltensor_versioned``. +Nanobind ignores other keyword arguments. +In particular, it cannot transfer the array's data to another device (such as +a GPU), nor can it make a copy of the data. +A custom class (such as ``MyArray`` above) could provide such functionality. +Often, the caller framework takes care of copying and inter-device data +transfer and does not ask the producer, ``MyArray``, to perform them. -The ``kwargs`` argument can be used to provide additional parameters (for -example to request a copy), please see the DLPack documentation for details. -Note that nanobind does not yet implement the versioned DLPack protocol. The -version number should be ignored for now. Frequently asked questions -------------------------- @@ -708,7 +796,3 @@ be more restrictive. Presently supported dtypes include signed/unsigned integers, floating point values, complex numbers, and boolean values. Some :ref:`nonstandard arithmetic types ` can be supported as well. - -Nanobind can receive and return *read-only* arrays via the buffer protocol when -exhanging data with NumPy. The DLPack interface currently ignores this -annotation. 
diff --git a/include/nanobind/nb_defs.h b/include/nanobind/nb_defs.h index 20d239c2..f04e233a 100644 --- a/include/nanobind/nb_defs.h +++ b/include/nanobind/nb_defs.h @@ -209,7 +209,7 @@ X(const X &) = delete; \ X &operator=(const X &) = delete; -#define NB_MOD_STATE_SIZE 80 +#define NB_MOD_STATE_SIZE 96 // Helper macros to ensure macro arguments are expanded before token pasting/stringification #define NB_MODULE_IMPL(name, variable) NB_MODULE_IMPL2(name, variable) diff --git a/include/nanobind/nb_lib.h b/include/nanobind/nb_lib.h index 11fb8d18..8fd35a1a 100644 --- a/include/nanobind/nb_lib.h +++ b/include/nanobind/nb_lib.h @@ -12,8 +12,8 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(dlpack) // The version of DLPack that is supported by libnanobind -static constexpr uint32_t major_version = 0; -static constexpr uint32_t minor_version = 0; +static constexpr uint32_t major_version = 1; +static constexpr uint32_t minor_version = 1; // Forward declarations for types in ndarray.h (1) struct dltensor; @@ -289,7 +289,7 @@ NB_CORE PyObject *capsule_new(const void *ptr, const char *name, struct func_data_prelim_base; /// Create a Python function object for the given function record -NB_CORE PyObject *nb_func_new(const func_data_prelim_base *data) noexcept; +NB_CORE PyObject *nb_func_new(const func_data_prelim_base *f) noexcept; // ======================================================================== @@ -481,7 +481,7 @@ NB_CORE ndarray_handle *ndarray_import(PyObject *o, cleanup_list *cleanup) noexcept; // Describe a local ndarray object using a DLPack capsule -NB_CORE ndarray_handle *ndarray_create(void *value, size_t ndim, +NB_CORE ndarray_handle *ndarray_create(void *data, size_t ndim, const size_t *shape, PyObject *owner, const int64_t *strides, dlpack::dtype dtype, bool ro, diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h index f71dc7e5..63802963 100644 --- a/include/nanobind/ndarray.h +++ b/include/nanobind/ndarray.h @@ -18,11 +18,16 @@ 
NAMESPACE_BEGIN(NB_NAMESPACE) -/// dlpack API/ABI data structures are part of a separate namespace +/// DLPack API/ABI data structures are part of a separate namespace. NAMESPACE_BEGIN(dlpack) enum class dtype_code : uint8_t { - Int = 0, UInt = 1, Float = 2, Bfloat = 4, Complex = 5, Bool = 6 + Int = 0, UInt = 1, Float = 2, Bfloat = 4, Complex = 5, Bool = 6, + Float8_E3M4 = 7, Float8_E4M3 = 8, Float8_E4M3B11FNUZ = 9, + Float8_E4M3FN = 10, Float8_E4M3FNUZ = 11, Float8_E5M2 = 12, + Float8_E5M2FNUZ = 13, Float8_E8M0FNU = 14, + Float6_E2M3FN = 15, Float6_E3M2FN = 16, + Float4_E2M1FN = 17 }; struct device { @@ -86,6 +91,7 @@ NB_FRAMEWORK(tensorflow, 3, "tensorflow.python.framework.ops.EagerTensor"); NB_FRAMEWORK(jax, 4, "jaxlib.xla_extension.DeviceArray"); NB_FRAMEWORK(cupy, 5, "cupy.ndarray"); NB_FRAMEWORK(memview, 6, "memoryview"); +NB_FRAMEWORK(array_api, 7, "ArrayLike"); NAMESPACE_BEGIN(device) NB_DEVICE(none, 0); NB_DEVICE(cpu, 1); NB_DEVICE(cuda, 2); diff --git a/src/nb_internals.cpp b/src/nb_internals.cpp index 73ac9bc6..f2b6b342 100644 --- a/src/nb_internals.cpp +++ b/src/nb_internals.cpp @@ -168,6 +168,8 @@ PyTypeObject *nb_meta_cache = nullptr; static const char* interned_c_strs[pyobj_name::string_count] { "value", "copy", + "clone", + "array", "from_dlpack", "__dlpack__", "max_version", diff --git a/src/nb_internals.h b/src/nb_internals.h index f919499e..211dbe46 100644 --- a/src/nb_internals.h +++ b/src/nb_internals.h @@ -426,6 +426,8 @@ struct pyobj_name { enum : int { value_str = 0, // string "value" copy_str, // string "copy" + clone_str, // string "clone" + array_str, // string "array" from_dlpack_str, // string "from_dlpack" dunder_dlpack_str, // string "__dlpack__" max_version_str, // string "max_version" @@ -490,11 +492,12 @@ inline void *inst_ptr(nb_inst *self) { } template struct scoped_pymalloc { - scoped_pymalloc(size_t size = 1) { - ptr = (T *) PyMem_Malloc(size * sizeof(T)); + scoped_pymalloc(size_t size = 1, size_t extra_bytes = 0) { + // Tip: 
construct objects in the extra bytes using placement new. + ptr = (T *) PyMem_Malloc(size * sizeof(T) + extra_bytes); if (!ptr) fail("scoped_pymalloc(): could not allocate %llu bytes of memory!", - (unsigned long long) size); + (unsigned long long) (size * sizeof(T) + extra_bytes)); } ~scoped_pymalloc() { PyMem_Free(ptr); } T *release() { diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index d84177a8..03c4492a 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -1,39 +1,139 @@ #include #include +#include #include "nb_internals.h" NAMESPACE_BEGIN(NB_NAMESPACE) -NAMESPACE_BEGIN(detail) + +NAMESPACE_BEGIN(dlpack) + +/// Indicates the managed_dltensor_versioned is read only. +static constexpr uint64_t flag_bitmask_read_only = 1UL << 0; + +struct version { + uint32_t major; + uint32_t minor; +}; + +NAMESPACE_END(dlpack) // ======================================================================== +NAMESPACE_BEGIN(detail) + +// DLPack version 0, deprecated Feb 2024, obsoleted March 2025 struct managed_dltensor { dlpack::dltensor dltensor; void *manager_ctx; void (*deleter)(managed_dltensor *); }; -struct ndarray_handle { - managed_dltensor *ndarray; - std::atomic refcount; - PyObject *owner, *self; - bool free_shape; - bool free_strides; - bool call_deleter; - bool ro; +// DLPack version 1, pre-release Feb 2024, release Sep 2024 +struct managed_dltensor_versioned { + dlpack::version version; + void *manager_ctx; + void (*deleter)(managed_dltensor_versioned *); + uint64_t flags = 0UL; + dlpack::dltensor dltensor; }; -static void ndarray_capsule_destructor(PyObject *o) { - error_scope scope; // temporarily save any existing errors - managed_dltensor *mt = - (managed_dltensor *) PyCapsule_GetPointer(o, "dltensor"); +static void mt_from_buffer_delete(managed_dltensor_versioned* self) { + gil_scoped_acquire guard; + Py_buffer *buf = (Py_buffer *) self->manager_ctx; + PyBuffer_Release(buf); + PyMem_Free(buf); + PyMem_Free(self); // This also frees shape and size 
arrays. +} + +// Forward declaration +struct ndarray_handle; + +template +static void mt_from_handle_delete(MT* self) { + gil_scoped_acquire guard; + ndarray_handle* th = (ndarray_handle *) self->manager_ctx; + PyMem_Free(self); + ndarray_dec_ref(th); +} +template +static void capsule_delete(PyObject *capsule) { + const char* capsule_name; + if constexpr (versioned) + capsule_name = "dltensor_versioned"; + else + capsule_name = "dltensor"; + + using MT = std::conditional_t; + error_scope scope; // temporarily save any existing errors + MT* mt = (MT*) PyCapsule_GetPointer(capsule, capsule_name); if (mt) - ndarray_dec_ref((ndarray_handle *) mt->manager_ctx); + mt->deleter(mt); else PyErr_Clear(); } +// Reference-counted wrapper for versioned or unversioned managed tensors +struct ndarray_handle { + union { + managed_dltensor *mt_unversioned; + managed_dltensor_versioned *mt_versioned; + }; + std::atomic refcount; + PyObject *owner, *self; + bool versioned; // This tags which union member is active. + bool free_strides; // True if we added strides to an imported tensor. + bool call_deleter; // True if tensor was imported, else PyMem_Free(mt). + bool ro; // Whether tensor is read-only. + + PyObject* make_capsule_unversioned() { + PyObject* capsule; + if (!versioned && mt_unversioned->manager_ctx == this) { + capsule = PyCapsule_New(mt_unversioned, "dltensor", + capsule_delete); + } else { + scoped_pymalloc mt; + memcpy(&mt->dltensor, + (versioned) ? 
&mt_versioned->dltensor + : &mt_unversioned->dltensor, + sizeof(dlpack::dltensor)); + mt->manager_ctx = this; + mt->deleter = mt_from_handle_delete; + capsule = PyCapsule_New(mt.release(), "dltensor", + capsule_delete); + } + check(capsule, "Could not make unversioned capsule"); + refcount++; + return capsule; + } + + PyObject* make_capsule_versioned() { + PyObject* capsule; + if (versioned && mt_versioned->manager_ctx == this) { + capsule = PyCapsule_New(mt_versioned, "dltensor_versioned", + capsule_delete); + } else { + scoped_pymalloc mt; + mt->version = {dlpack::major_version, dlpack::minor_version}; + mt->manager_ctx = this; + mt->deleter = mt_from_handle_delete; + mt->flags = (ro) ? dlpack::flag_bitmask_read_only : 0; + memcpy(&mt->dltensor, + (versioned) ? &mt_versioned->dltensor + : &mt_unversioned->dltensor, + sizeof(dlpack::dltensor)); + capsule = PyCapsule_New(mt.release(), "dltensor_versioned", + capsule_delete); + } + check(capsule, "Could not make versioned capsule"); + refcount++; + return capsule; + } +}; + +// ======================================================================== + static void nb_ndarray_dealloc(PyObject *self) { PyTypeObject *tp = Py_TYPE(self); ndarray_dec_ref(((nb_ndarray *) self)->th); @@ -41,10 +141,10 @@ static void nb_ndarray_dealloc(PyObject *self) { Py_DECREF(tp); } -static int nd_ndarray_tpbuffer(PyObject *exporter, Py_buffer *view, int) { - nb_ndarray *self = (nb_ndarray *) exporter; - - dlpack::dltensor &t = self->th->ndarray->dltensor; +static int nb_ndarray_getbuffer(PyObject *self, Py_buffer *view, int) { + ndarray_handle *th = ((nb_ndarray *) self)->th; + dlpack::dltensor &t = (th->versioned) ? 
th->mt_versioned->dltensor + : th->mt_unversioned->dltensor; if (t.device.device_type != device::cpu::value) { PyErr_SetString(PyExc_BufferError, "Only CPU-allocated ndarrays can be " @@ -96,84 +196,123 @@ static int nd_ndarray_tpbuffer(PyObject *exporter, Py_buffer *view, int) { } if (!format || t.dtype.lanes != 1) { - PyErr_SetString( - PyExc_BufferError, - "Don't know how to convert DLPack dtype into buffer protocol format!"); + PyErr_SetString(PyExc_BufferError, + "Cannot convert DLPack dtype into buffer protocol format!"); return -1; } - view->format = (char *) format; - view->itemsize = t.dtype.bits / 8; view->buf = (void *) ((uintptr_t) t.data + t.byte_offset); - view->obj = exporter; - Py_INCREF(exporter); + view->obj = self; + Py_INCREF(self); - Py_ssize_t len = view->itemsize; - scoped_pymalloc strides((size_t) t.ndim), - shape((size_t) t.ndim); + scoped_pymalloc shape_and_strides(2 * (size_t) t.ndim); + Py_ssize_t* shape = shape_and_strides.get(); + Py_ssize_t* strides = shape + t.ndim; + const Py_ssize_t itemsize = t.dtype.bits / 8; + Py_ssize_t len = itemsize; for (size_t i = 0; i < (size_t) t.ndim; ++i) { len *= (Py_ssize_t) t.shape[i]; - strides[i] = (Py_ssize_t) t.strides[i] * view->itemsize; shape[i] = (Py_ssize_t) t.shape[i]; + strides[i] = (Py_ssize_t) t.strides[i] * itemsize; } - view->ndim = t.ndim; view->len = len; - view->readonly = self->th->ro; + view->itemsize = itemsize; + view->readonly = th->ro; + view->ndim = t.ndim; + view->format = (char *) format; + view->shape = shape; + view->strides = strides; view->suboffsets = nullptr; - view->internal = nullptr; - view->strides = strides.release(); - view->shape = shape.release(); + view->internal = shape_and_strides.release(); return 0; } static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) { - PyMem_Free(view->shape); - PyMem_Free(view->strides); + PyMem_Free(view->internal); } +// This function implements __dlpack__() for a nanobind.nb_ndarray. 
+static PyObject *nb_ndarray_dlpack(PyObject *self, PyObject *const *args, + Py_ssize_t nargsf, PyObject *kwnames) { + if (PyVectorcall_NARGS(nargsf) != 0) { + PyErr_SetString(PyExc_TypeError, + "__dlpack__() does not accept positional arguments"); + return nullptr; + } + Py_ssize_t nkwargs = (kwnames) ? NB_TUPLE_GET_SIZE(kwnames) : 0; + + long max_major_version = 0; + for (Py_ssize_t i = 0; i < nkwargs; ++i) { + PyObject* key = NB_TUPLE_GET_ITEM(kwnames, i); + if (key == static_pyobjects[pyobj_name::dl_device_str] || + key == static_pyobjects[pyobj_name::copy_str]) + // These keyword arguments are ignored. This branch of the code + // is here to avoid a Python call to RichCompare if these kwargs + // are provided by the caller. + continue; + if (key == static_pyobjects[pyobj_name::max_version_str] || + PyObject_RichCompareBool(key, + static_pyobjects[pyobj_name::max_version_str], Py_EQ) == 1) { + PyObject* value = args[i]; + if (value == Py_None) + break; + if (!PyTuple_Check(value) || NB_TUPLE_GET_SIZE(value) != 2) { + PyErr_SetString(PyExc_TypeError, + "max_version must be None or tuple[int, int]"); + return nullptr; + } + max_major_version = PyLong_AsLong(NB_TUPLE_GET_ITEM(value, 0)); + break; + } + } -static PyObject *nb_ndarray_dlpack(PyObject *self, PyTypeObject *, - PyObject *const *, Py_ssize_t , - PyObject *) { - nb_ndarray *self_nd = (nb_ndarray *) self; - ndarray_handle *th = self_nd->th; + ndarray_handle *th = ((nb_ndarray *) self)->th; + PyObject *capsule; + if (max_major_version >= dlpack::major_version) + capsule = th->make_capsule_versioned(); + else + capsule = th->make_capsule_unversioned(); - PyObject *r = - PyCapsule_New(th->ndarray, "dltensor", ndarray_capsule_destructor); - if (r) - ndarray_inc_ref(th); - return r; + return capsule; } -static PyObject *nb_ndarray_dlpack_device(PyObject *self, PyTypeObject *, - PyObject *const *, Py_ssize_t , - PyObject *) { - nb_ndarray *self_nd = (nb_ndarray *) self; - dlpack::dltensor &t = 
self_nd->th->ndarray->dltensor; - PyObject *r = PyTuple_New(2); - PyObject *r0 = PyLong_FromLong(t.device.device_type); - PyObject *r1 = PyLong_FromLong(t.device.device_id); - if (!r || !r0 || !r1) { - Py_XDECREF(r); - Py_XDECREF(r0); - Py_XDECREF(r1); - return nullptr; +// This function implements __dlpack_device__() for a nanobind.nb_ndarray. +static PyObject *nb_ndarray_dlpack_device(PyObject *self, PyObject *) { + ndarray_handle *th = ((nb_ndarray *) self)->th; + dlpack::dltensor& t = (th->versioned) + ? th->mt_versioned->dltensor + : th->mt_unversioned->dltensor; + PyObject *r; + if (t.device.device_type == 1 && t.device.device_id == 0) { + r = static_pyobjects[pyobj_name::dl_cpu_tpl]; + Py_INCREF(r); + } else { + r = PyTuple_New(2); + PyObject *r0 = PyLong_FromLong(t.device.device_type); + PyObject *r1 = PyLong_FromLong(t.device.device_id); + if (!r || !r0 || !r1) { + Py_XDECREF(r); + Py_XDECREF(r0); + Py_XDECREF(r1); + return nullptr; + } + NB_TUPLE_SET_ITEM(r, 0, r0); + NB_TUPLE_SET_ITEM(r, 1, r1); } - NB_TUPLE_SET_ITEM(r, 0, r0); - NB_TUPLE_SET_ITEM(r, 1, r1); return r; } -static PyMethodDef nb_ndarray_members[] = { - { "__dlpack__", (PyCFunction) (void *) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr }, - { "__dlpack_device__", (PyCFunction) (void *) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr }, +static PyMethodDef nb_ndarray_methods[] = { + { "__dlpack__", (PyCFunction) (void *) nb_ndarray_dlpack, + METH_FASTCALL | METH_KEYWORDS, nullptr }, + { "__dlpack_device__", nb_ndarray_dlpack_device, METH_NOARGS, nullptr }, { nullptr, nullptr, 0, nullptr } }; -static PyTypeObject *nd_ndarray_tp() noexcept { +static PyTypeObject *nb_ndarray_tp() noexcept { nb_internals *internals_ = internals; PyTypeObject *tp = internals_->nb_ndarray.load_acquire(); @@ -185,9 +324,9 @@ static PyTypeObject *nd_ndarray_tp() noexcept { PyType_Slot slots[] = { { Py_tp_dealloc, (void *) nb_ndarray_dealloc }, - { Py_tp_methods, (void *) 
nb_ndarray_members }, + { Py_tp_methods, (void *) nb_ndarray_methods }, #if PY_VERSION_HEX >= 0x03090000 - { Py_bf_getbuffer, (void *) nd_ndarray_tpbuffer }, + { Py_bf_getbuffer, (void *) nb_ndarray_getbuffer }, { Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer }, #endif { 0, nullptr } @@ -205,7 +344,7 @@ static PyTypeObject *nd_ndarray_tp() noexcept { check(tp, "nb_ndarray type creation failed!"); #if PY_VERSION_HEX < 0x03090000 - tp->tp_as_buffer->bf_getbuffer = nd_ndarray_tpbuffer; + tp->tp_as_buffer->bf_getbuffer = nb_ndarray_getbuffer; tp->tp_as_buffer->bf_releasebuffer = nb_ndarray_releasebuffer; #endif @@ -215,14 +354,18 @@ static PyTypeObject *nd_ndarray_tp() noexcept { return tp; } -static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) { - scoped_pymalloc view; - scoped_pymalloc mt; +// ======================================================================== + +using mt_unique_ptr_t = std::unique_ptr; +static mt_unique_ptr_t make_mt_from_buffer_protocol(PyObject *o, bool ro) { + mt_unique_ptr_t mt_unique_ptr(nullptr, &mt_from_buffer_delete); + scoped_pymalloc view; if (PyObject_GetBuffer(o, view.get(), ro ? 
PyBUF_RECORDS_RO : PyBUF_RECORDS)) { PyErr_Clear(); - return nullptr; + return mt_unique_ptr; } char format_c = 'B'; @@ -233,7 +376,7 @@ static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) { bool skip_first = format_c == '@' || format_c == '='; int32_t num = 1; - if(*(uint8_t *) &num == 1) { + if (*(uint8_t *) &num == 1) { if (format_c == '<') skip_first = true; } else { @@ -274,8 +417,7 @@ static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) { case '?': dt.code = (uint8_t) dlpack::dtype_code::Bool; break; - default: - fail = true; + default: fail = true; } if (is_complex) { @@ -289,71 +431,64 @@ static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) { if (fail) { PyBuffer_Release(view.get()); - return nullptr; + return mt_unique_ptr; } - mt->deleter = [](managed_dltensor *mt2) { - gil_scoped_acquire guard; - Py_buffer *buf = (Py_buffer *) mt2->manager_ctx; - PyBuffer_Release(buf); - PyMem_Free(mt2->manager_ctx); - PyMem_Free(mt2->dltensor.shape); - PyMem_Free(mt2->dltensor.strides); - PyMem_Free(mt2); - }; + int32_t ndim = view->ndim; - /* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but - PyTorch unfortunately ignores the 'byte_offset' value.. :-( */ + static_assert(alignof(managed_dltensor_versioned) >= alignof(int64_t)); + scoped_pymalloc mt(1, 2 * sizeof(int64_t)*ndim); + int64_t* shape = nullptr; + int64_t* strides = nullptr; + if (ndim > 0) { + shape = new ((void*) (mt.get() + 1)) int64_t[2 * ndim]; + strides = shape + ndim; + } + + /* See comments in function ndarray_create(). 
*/ #if 0 - uintptr_t value_int = (uintptr_t) view->buf, - value_rounded = (value_int / 256) * 256; + uintptr_t data_uint = (uintptr_t) view->buf; + void* data_ptr = (void *) (data_uint & ~uintptr_t{255}); + uint64_t data_offset = data_uint & uintptr_t{255}; #else - uintptr_t value_int = (uintptr_t) view->buf, - value_rounded = value_int; + void* data_ptr = view->buf; + constexpr uint64_t data_offset = 0UL; #endif - mt->dltensor.data = (void *) value_rounded; + mt->dltensor.data = data_ptr; mt->dltensor.device = { device::cpu::value, 0 }; - mt->dltensor.ndim = view->ndim; + mt->dltensor.ndim = ndim; mt->dltensor.dtype = dt; - mt->dltensor.byte_offset = value_int - value_rounded; + mt->dltensor.shape = shape; + mt->dltensor.strides = strides; + mt->dltensor.byte_offset = data_offset; - scoped_pymalloc strides((size_t) view->ndim); - scoped_pymalloc shape((size_t) view->ndim); - const int64_t itemsize = static_cast(view->itemsize); - for (size_t i = 0; i < (size_t) view->ndim; ++i) { + const int64_t itemsize = (int64_t) view->itemsize; + for (int32_t i = 0; i < ndim; ++i) { int64_t stride = view->strides[i] / itemsize; if (stride * itemsize != view->strides[i]) { PyBuffer_Release(view.get()); - return nullptr; + return mt_unique_ptr; } strides[i] = stride; shape[i] = (int64_t) view->shape[i]; } + mt->version = {dlpack::major_version, dlpack::minor_version}; mt->manager_ctx = view.release(); - mt->dltensor.shape = shape.release(); - mt->dltensor.strides = strides.release(); - - return PyCapsule_New(mt.release(), "dltensor", [](PyObject *o) { - error_scope scope; // temporarily save any existing errors - managed_dltensor *mt = - (managed_dltensor *) PyCapsule_GetPointer(o, "dltensor"); - if (mt) { - if (mt->deleter) - mt->deleter(mt); - } else { - PyErr_Clear(); - } - }); + mt->deleter = mt_from_buffer_delete; + mt->flags = (ro) ? 
dlpack::flag_bitmask_read_only : 0; + + mt_unique_ptr.reset(mt.release()); + return mt_unique_ptr; } bool ndarray_check(PyObject *o) noexcept { - if (PyObject_HasAttrString(o, "__dlpack__") || PyObject_CheckBuffer(o)) + if (PyObject_HasAttr(o, static_pyobjects[pyobj_name::dunder_dlpack_str]) || + PyObject_CheckBuffer(o)) return true; PyTypeObject *tp = Py_TYPE(o); - if (tp == &PyCapsule_Type) return true; @@ -378,19 +513,45 @@ bool ndarray_check(PyObject *o) noexcept { } -ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, +ndarray_handle *ndarray_import(PyObject *src, const ndarray_config *c, bool convert, cleanup_list *cleanup) noexcept { object capsule; - bool is_pycapsule = PyCapsule_CheckExact(o); + const bool src_is_pycapsule = PyCapsule_CheckExact(src); + mt_unique_ptr_t mt_unique_ptr(nullptr, &mt_from_buffer_delete); - // If this is not a capsule, try calling o.__dlpack__() - if (!is_pycapsule) { - capsule = steal(PyObject_CallMethod(o, "__dlpack__", nullptr)); + if (src_is_pycapsule) { + capsule = borrow(src); + } else { + // Try calling src.__dlpack__() +#if PY_VERSION_HEX < 0x03090000 + capsule = steal(PyObject_CallMethod(src, "__dlpack__", nullptr)); +#else + PyObject* args[] = {src, static_pyobjects[pyobj_name::dl_version_tpl]}; + Py_ssize_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET; + capsule = steal(PyObject_VectorcallMethod( + static_pyobjects[pyobj_name::dunder_dlpack_str], + args, nargsf, + static_pyobjects[pyobj_name::max_version_tpl])); + + // Python array API standard v2023 introduced max_version. + // Try calling src.__dlpack__() without any kwargs. 
+ if (!capsule.is_valid() && PyErr_ExceptionMatches(PyExc_TypeError)) { + PyErr_Clear(); + capsule = steal(PyObject_VectorcallMethod( + static_pyobjects[pyobj_name::dunder_dlpack_str], + args, nargsf, nullptr)); + } +#endif + // Try creating an ndarray via the buffer protocol if (!capsule.is_valid()) { PyErr_Clear(); - PyTypeObject *tp = Py_TYPE(o); + mt_unique_ptr = make_mt_from_buffer_protocol(src, c->ro); + } + // Try the function to_dlpack(), already obsolete in array API v2021 + if (!mt_unique_ptr && !capsule.is_valid()) { + PyTypeObject *tp = Py_TYPE(src); try { const char *module_name = borrow(handle(tp).attr("__module__")).c_str(); @@ -398,59 +559,68 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, object package; if (strncmp(module_name, "tensorflow.", 11) == 0) package = module_::import_("tensorflow.experimental.dlpack"); - else if (strcmp(module_name, "torch") == 0) + else if (strncmp(module_name, "torch", 5) == 0) package = module_::import_("torch.utils.dlpack"); else if (strncmp(module_name, "jaxlib", 6) == 0) package = module_::import_("jax.dlpack"); if (package.is_valid()) - capsule = package.attr("to_dlpack")(handle(o)); + capsule = package.attr("to_dlpack")(handle(src)); } catch (...) { capsule.reset(); } + if (!capsule.is_valid()) + return nullptr; } + } - // Try creating an ndarray via the buffer protocol - if (!capsule.is_valid()) - capsule = steal(dlpack_from_buffer_protocol(o, c->ro)); - - if (!capsule.is_valid()) - return nullptr; + void* mt; // can be versioned or unversioned + bool versioned = true; + if (mt_unique_ptr) { + mt = mt_unique_ptr.get(); } else { - capsule = borrow(o); + // Extract the managed_dltensor{_versioned} pointer from the capsule. 
+ mt = PyCapsule_GetPointer(capsule.ptr(), "dltensor_versioned"); + if (!mt) { + PyErr_Clear(); + versioned = false; + mt = PyCapsule_GetPointer(capsule.ptr(), "dltensor"); + if (!mt) { + PyErr_Clear(); + return nullptr; + } + } } - // Extract the pointer underlying the capsule - void *ptr = PyCapsule_GetPointer(capsule.ptr(), "dltensor"); - if (!ptr) { - PyErr_Clear(); + dlpack::dltensor& t = (versioned) + ? ((managed_dltensor_versioned *) mt)->dltensor + : ((managed_dltensor *) mt)->dltensor; + + uint64_t flags = (versioned) ? ((managed_dltensor_versioned *) mt)->flags + : 0UL; + + // Reject a read-only ndarray if a writable one is required, and + // reject an ndarray not on the required device. + if ((!c->ro && (flags & dlpack::flag_bitmask_read_only)) + || (c->device_type != 0 && t.device.device_type != c->device_type)) { return nullptr; } - // Check if the ndarray satisfies the requirements - dlpack::dltensor &t = ((managed_dltensor *) ptr)->dltensor; - + // Check if the ndarray satisfies the remaining requirements. 
bool has_dtype = c->dtype != dlpack::dtype(), - has_device_type = c->device_type != 0, has_shape = c->ndim != -1, has_order = c->order != '\0'; - bool pass_dtype = true, pass_device = true, - pass_shape = true, pass_order = true; + bool pass_dtype = true, pass_shape = true, pass_order = true; if (has_dtype) pass_dtype = t.dtype == c->dtype; - if (has_device_type) - pass_device = t.device.device_type == c->device_type; - if (has_shape) { - pass_shape &= c->ndim == t.ndim; - + pass_shape = t.ndim == c->ndim; if (pass_shape) { for (int32_t i = 0; i < c->ndim; ++i) { - if (c->shape[i] != t.shape[i] && - c->shape[i] != -1) { + if (c->shape[i] != -1 && t.shape[i] != c->shape[i]) { pass_shape = false; break; } @@ -499,14 +669,15 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, } } - bool refused_conversion = t.dtype.code == (uint8_t) dlpack::dtype_code::Complex && - has_dtype && - c->dtype.code != (uint8_t) dlpack::dtype_code::Complex; + // Do not convert shape and do not convert complex numbers to non-complex. + convert &= pass_shape & + !(t.dtype.code == (uint8_t) dlpack::dtype_code::Complex + && has_dtype + && c->dtype.code != (uint8_t) dlpack::dtype_code::Complex); - // Support implicit conversion of 'dtype' and order - if (pass_device && pass_shape && (!pass_dtype || !pass_order) && convert && - capsule.ptr() != o && !refused_conversion) { - PyTypeObject *tp = Py_TYPE(o); + // Support implicit conversion of dtype and order. 
+ if (convert && (!pass_dtype || !pass_order) && !src_is_pycapsule) { + PyTypeObject *tp = Py_TYPE(src); str module_name_o = borrow(handle(tp).attr("__module__")); const char *module_name = module_name_o.c_str(); @@ -518,16 +689,24 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, if (dt.lanes != 1) return nullptr; - const char *prefix = nullptr; char dtype[11]; if (dt.code == (uint8_t) dlpack::dtype_code::Bool) { std::strcpy(dtype, "bool"); } else { + const char *prefix = nullptr; switch (dt.code) { - case (uint8_t) dlpack::dtype_code::Int: prefix = "int"; break; - case (uint8_t) dlpack::dtype_code::UInt: prefix = "uint"; break; - case (uint8_t) dlpack::dtype_code::Float: prefix = "float"; break; - case (uint8_t) dlpack::dtype_code::Complex: prefix = "complex"; break; + case (uint8_t) dlpack::dtype_code::Int: + prefix = "int"; + break; + case (uint8_t) dlpack::dtype_code::UInt: + prefix = "uint"; + break; + case (uint8_t) dlpack::dtype_code::Float: + prefix = "float"; + break; + case (uint8_t) dlpack::dtype_code::Complex: + prefix = "complex"; + break; default: return nullptr; } @@ -536,25 +715,24 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, object converted; try { - if (strcmp(module_name, "numpy") == 0 || strcmp(module_name, "cupy") == 0) { - converted = handle(o).attr("astype")(dtype, order); - } else if (strcmp(module_name, "torch") == 0) { - converted = handle(o).attr("to")( - arg("dtype") = module_::import_("torch").attr(dtype)); + if (strncmp(module_name, "numpy", 5) == 0 + || strncmp(module_name, "cupy", 4) == 0) { + converted = handle(src).attr("astype")(dtype, order); + } else if (strncmp(module_name, "torch", 5) == 0) { + module_ torch = module_::import_("torch"); + converted = handle(src).attr("to")(torch.attr(dtype)); if (c->order == 'C') converted = converted.attr("contiguous")(); } else if (strncmp(module_name, "tensorflow.", 11) == 0) { - converted = module_::import_("tensorflow") - 
.attr("cast")(handle(o), dtype); + module_ tensorflow = module_::import_("tensorflow"); + converted = tensorflow.attr("cast")(handle(src), dtype); } else if (strncmp(module_name, "jaxlib", 6) == 0) { - converted = handle(o).attr("astype")(dtype); + converted = handle(src).attr("astype")(dtype); } } catch (...) { converted.reset(); } - // Potentially try again recursively - if (!converted.is_valid()) { - return nullptr; - } else { + // Potentially try once again, recursively + if (converted.is_valid()) { ndarray_handle *h = ndarray_import(converted.ptr(), c, false, nullptr); if (h && cleanup) @@ -563,27 +741,31 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, } } - if (!pass_dtype || !pass_device || !pass_shape || !pass_order) + if (!pass_dtype || !pass_shape || !pass_order) return nullptr; // Create a reference-counted wrapper scoped_pymalloc result; - result->ndarray = (managed_dltensor *) ptr; + if (versioned) + result->mt_versioned = (managed_dltensor_versioned *) mt; + else + result->mt_unversioned = (managed_dltensor *) mt; + result->refcount = 0; result->owner = nullptr; - result->free_shape = false; + result->versioned = versioned; result->call_deleter = true; result->ro = c->ro; - if (is_pycapsule) { + if (src_is_pycapsule) { result->self = nullptr; } else { - result->self = o; - Py_INCREF(o); + result->self = src; + Py_INCREF(src); } - // Ensure that the strides member is always initialized - if (t.strides) { + // If ndim > 0, ensure that the strides member is initialized. 
+ if (t.strides || t.ndim == 0) { + result->free_strides = false; + } else { + result->free_strides = true; @@ -593,16 +775,19 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_config *c, strides[i] = accum; accum *= t.shape[i]; } - t.strides = strides.release(); } - // Mark the dltensor capsule as "consumed" - if (PyCapsule_SetName(capsule.ptr(), "used_dltensor") || - PyCapsule_SetDestructor(capsule.ptr(), nullptr)) - check(false, "nanobind::detail::ndarray_import(): could not mark " - "dltensor capsule as consumed!"); + if (capsule.is_valid()) { + // Mark the dltensor capsule as used, i.e., "consumed". + const char* used_name = (versioned) ? "used_dltensor_versioned" + : "used_dltensor"; + if (PyCapsule_SetName(capsule.ptr(), used_name) || + PyCapsule_SetDestructor(capsule.ptr(), nullptr)) + check(false, "ndarray_import(): could not mark capsule as used"); + } + mt_unique_ptr.release(); return result.release(); } @@ -610,7 +795,8 @@ dlpack::dltensor *ndarray_inc_ref(ndarray_handle *th) noexcept { if (!th) return nullptr; ++th->refcount; - return &th->ndarray->dltensor; + return (th->versioned) ? &th->mt_versioned->dltensor + : &th->mt_unversioned->dltensor; } void ndarray_dec_ref(ndarray_handle *th) noexcept { @@ -625,50 +811,64 @@ void ndarray_dec_ref(ndarray_handle *th) noexcept { Py_XDECREF(th->owner); Py_XDECREF(th->self); - managed_dltensor *mt = th->ndarray; - if (th->free_shape) { - PyMem_Free(mt->dltensor.shape); - mt->dltensor.shape = nullptr; - } - if (th->free_strides) { - PyMem_Free(mt->dltensor.strides); - mt->dltensor.strides = nullptr; - } - if (th->call_deleter) { + if (th->versioned) { + managed_dltensor_versioned *mt = th->mt_versioned; + if (th->free_strides) { + PyMem_Free(mt->dltensor.strides); + mt->dltensor.strides = nullptr; + } + if (th->call_deleter) { + if (mt->deleter) + mt->deleter(mt); + } else { + PyMem_Free(mt); // This also frees shape and strides arrays. 
+ } + } else { + managed_dltensor *mt = th->mt_unversioned; + if (th->free_strides) { + PyMem_Free(mt->dltensor.strides); + mt->dltensor.strides = nullptr; + } + assert(th->call_deleter); if (mt->deleter) mt->deleter(mt); - } else { - PyMem_Free(mt); } PyMem_Free(th); } } -ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in, +ndarray_handle *ndarray_create(void *data, size_t ndim, const size_t *shape_in, PyObject *owner, const int64_t *strides_in, dlpack::dtype dtype, bool ro, int device_type, int device_id, char order) { - /* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but - PyTorch unfortunately ignores the 'byte_offset' value.. :-( */ + /* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, + but this requirement is generally ignored. Also, PyTorch has/had + a bug in ignoring byte_offset and assuming it's zero. + It would be wrong to split the 64-bit raw pointer into two pieces, + as disabled below, since the pointer dltensor.data must point to + allocated memory (i.e., memory that can be accessed). + A byte_offset can be used to support array slicing when data is an + opaque device pointer or handle, on which arithmetic is impossible. + However, this function is not slicing the data. 
+ See also: https://github.com/data-apis/array-api/discussions/779 */ #if 0 - uintptr_t value_int = (uintptr_t) value, - value_rounded = (value_int / 256) * 256; + uintptr_t data_uint = (uintptr_t) data; + data = (void *) (data_uint & ~uintptr_t{255}); // upper bits + uint64_t data_offset = data_uint & uintptr_t{255}; // lowest 8 bits #else - uintptr_t value_int = (uintptr_t) value, - value_rounded = value_int; + constexpr uint64_t data_offset = 0UL; #endif if (device_type == 0) device_type = device::cpu::value; - scoped_pymalloc ndarray; - scoped_pymalloc result; - scoped_pymalloc shape(ndim), strides(ndim); - - auto deleter = [](managed_dltensor *mt) { - gil_scoped_acquire guard; - ndarray_handle *th = (ndarray_handle *) mt->manager_ctx; - ndarray_dec_ref(th); - }; + static_assert(alignof(managed_dltensor_versioned) >= alignof(int64_t)); + scoped_pymalloc mt(1, 2 * sizeof(int64_t)*ndim); + int64_t* shape = nullptr; + int64_t* strides = nullptr; + if (ndim > 0) { + shape = new ((void*) (mt.get() + 1)) int64_t[2 * ndim]; + strides = shape + ndim; + } for (size_t i = 0; i < ndim; ++i) shape[i] = (int64_t) shape_in[i]; @@ -689,27 +889,32 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in, prod *= (int64_t) shape_in[i]; } } else { - check(false, "nanobind::detail::ndarray_create(): unknown " - "memory order requested!"); + check(false, "ndarray_create(): unknown memory order requested!"); } } - ndarray->dltensor.data = (void *) value_rounded; - ndarray->dltensor.device.device_type = (int32_t) device_type; - ndarray->dltensor.device.device_id = (int32_t) device_id; - ndarray->dltensor.ndim = (int32_t) ndim; - ndarray->dltensor.dtype = dtype; - ndarray->dltensor.byte_offset = value_int - value_rounded; - ndarray->dltensor.shape = shape.release(); - ndarray->dltensor.strides = strides.release(); - ndarray->manager_ctx = result.get(); - ndarray->deleter = deleter; - result->ndarray = (managed_dltensor *) ndarray.release(); + scoped_pymalloc 
result; + + mt->version = {dlpack::major_version, dlpack::minor_version}; + mt->manager_ctx = result.get(); + mt->deleter = [](managed_dltensor_versioned *self) { + ndarray_dec_ref((ndarray_handle *) self->manager_ctx); + }; + mt->flags = (ro) ? dlpack::flag_bitmask_read_only : 0; + mt->dltensor.data = data; + mt->dltensor.device.device_type = (int32_t) device_type; + mt->dltensor.device.device_id = (int32_t) device_id; + mt->dltensor.ndim = (int32_t) ndim; + mt->dltensor.dtype = dtype; + mt->dltensor.shape = shape; + mt->dltensor.strides = strides; + mt->dltensor.byte_offset = data_offset; + result->mt_versioned = mt.release(); result->refcount = 0; result->owner = owner; result->self = nullptr; - result->free_shape = true; - result->free_strides = true; + result->versioned = true; + result->free_strides = false; result->call_deleter = false; result->ro = ro; Py_XINCREF(owner); @@ -717,7 +922,7 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in, } PyObject *ndarray_export(ndarray_handle *th, int framework, - rv_policy policy, cleanup_list *cleanup) noexcept { + rv_policy policy, cleanup_list *cleanup) noexcept { if (!th) return none().release().ptr(); @@ -765,57 +970,95 @@ PyObject *ndarray_export(ndarray_handle *th, int framework, object o; if (copy && framework == no_framework::value && th->self) { o = borrow(th->self); - } else if (framework == numpy::value || framework == jax::value || framework == memview::value) { - nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp()); + } else if (framework == no_framework::value || + framework == tensorflow::value) { + // Make a new capsule wrapping an unversioned managed_dltensor. + o = steal(th->make_capsule_unversioned()); + } else { + // Make a Python object providing the buffer interface and having + // the two DLPack methods __dlpack__() and __dlpack_device__(). 
+ nb_ndarray *h = PyObject_New(nb_ndarray, nb_ndarray_tp()); if (!h) return nullptr; h->th = th; ndarray_inc_ref(th); o = steal((PyObject *) h); - } else { - o = steal(PyCapsule_New(th->ndarray, "dltensor", - ndarray_capsule_destructor)); - ndarray_inc_ref(th); } - try { - if (framework == numpy::value) { - return module_::import_("numpy") - .attr("array")(o, arg("copy") = copy) - .release() - .ptr(); - } else if (framework == memview::value) { - return PyMemoryView_FromObject(o.ptr()); - } else { - const char *pkg_name; - switch (framework) { - case pytorch::value: pkg_name = "torch.utils.dlpack"; break; - case tensorflow::value: pkg_name = "tensorflow.experimental.dlpack"; break; - case jax::value: pkg_name = "jax.dlpack"; break; - case cupy::value: pkg_name = "cupy"; break; - default: pkg_name = nullptr; - } + if (framework == numpy::value) { + try { +#if PY_VERSION_HEX < 0x03090000 + module_ pkg_mod = module_::import_("numpy"); + return pkg_mod.attr(static_pyobjects[pyobj_name::array_str])( + o, arg("copy") = copy) + .release().ptr(); +#else + PyObject* pkg_mod = module_import("numpy"); + PyObject* args[] = {pkg_mod, o.ptr(), + (copy) ? 
Py_True : Py_False}; + Py_ssize_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET; + return PyObject_VectorcallMethod( + static_pyobjects[pyobj_name::array_str], args, nargsf, + static_pyobjects[pyobj_name::copy_tpl]); +#endif + } catch (const std::exception &e) { + PyErr_Format(PyExc_TypeError, + "could not export nanobind::ndarray: %s", + e.what()); + return nullptr; + } + } - if (pkg_name) - o = module_::import_(pkg_name).attr("from_dlpack")(o); + try { + const char* pkg_name; + switch (framework) { + case pytorch::value: + pkg_name = "torch.utils.dlpack"; + break; + case tensorflow::value: + pkg_name = "tensorflow.experimental.dlpack"; + break; + case jax::value: + pkg_name = "jax.dlpack"; + break; + case cupy::value: + pkg_name = "cupy"; + break; + case memview::value: + return PyMemoryView_FromObject(o.ptr()); + default: + pkg_name = nullptr; + } + if (pkg_name) { +#if PY_VERSION_HEX < 0x03090000 + o = module_::import_(pkg_name) + .attr(static_pyobjects[pyobj_name::from_dlpack_str])(o); +#else + PyObject* pkg_mod = module_import(pkg_name); + PyObject* args[] = {pkg_mod, o.ptr()}; + Py_ssize_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET; + o = steal(PyObject_VectorcallMethod( + static_pyobjects[pyobj_name::from_dlpack_str], + args, nargsf, nullptr)); +#endif } } catch (const std::exception &e) { PyErr_Format(PyExc_TypeError, - "could not export nb::ndarray: %s", + "could not export nanobind::ndarray: %s", e.what()); return nullptr; } if (copy) { - const char* copy_str = "copy"; + PyObject* copy_function_name = static_pyobjects[pyobj_name::copy_str]; if (framework == pytorch::value) - copy_str = "clone"; + copy_function_name = static_pyobjects[pyobj_name::clone_str]; try { - o = o.attr(copy_str)(); + o = o.attr(copy_function_name)(); } catch (std::exception &e) { PyErr_Format(PyExc_RuntimeError, - "nanobind::detail::ndarray_export(): copy failed: %s", + "copying nanobind::ndarray failed: %s", e.what()); return nullptr; } diff --git a/tests/test_jax.cpp 
b/tests/test_jax.cpp index 0729d1dd..9f9f597c 100644 --- a/tests/test_jax.cpp +++ b/tests/test_jax.cpp @@ -8,15 +8,18 @@ int destruct_count = 0; NB_MODULE(test_jax_ext, m) { m.def("destruct_count", []() { return destruct_count; }); m.def("ret_jax", []() { - float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; + struct alignas(64) Buf { + float f[8]; + }; + Buf *buf = new Buf({ 1, 2, 3, 4, 5, 6, 7, 8 }); size_t shape[2] = { 2, 4 }; - nb::capsule deleter(f, [](void *data) noexcept { + nb::capsule deleter(buf, [](void *p) noexcept { destruct_count++; - delete[] (float *) data; + delete (Buf *) p; }); - return nb::ndarray>(f, 2, shape, + return nb::ndarray>(buf->f, 2, shape, deleter); }); } diff --git a/tests/test_jax.py b/tests/test_jax.py index 7eebd8c5..6802e6df 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -14,7 +14,7 @@ def needs_jax(x): @needs_jax -def test01_constrain_order_jax(): +def test01_constrain_order(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -27,7 +27,7 @@ def test01_constrain_order_jax(): @needs_jax -def test02_implicit_conversion_jax(): +def test02_implicit_conversion(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -60,5 +60,38 @@ def test03_return_jax(): @needs_jax -def test04_check_jax(): +def test04_check(): assert t.check(jnp.zeros((1))) + + +@needs_jax +def test05_passthrough(): + a = tj.ret_jax() + b = t.passthrough(a) + assert a is b + + a = jnp.array([1, 2, 3]) + b = t.passthrough(a) + assert a is b + + a = None + with pytest.raises(TypeError) as excinfo: + b = t.passthrough(a) + assert 'incompatible function arguments' in str(excinfo.value) + b = t.passthrough_arg_none(a) + assert a is b + + +@needs_jax +def test06_ro_array(): + if (not hasattr(jnp, '__array_api_version__') + or jnp.__array_api_version__ < '2024'): + pytest.skip('jax version is too old') + a = jnp.array([1, 2], dtype=jnp.float32) # JAX arrays are immutable. 
+ assert t.accept_ro(a) == 1 + # If the next line fails, delete it, update the array_api_version above, + # and uncomment the three lines below. + assert t.accept_rw(a) == 1 + # with pytest.raises(TypeError) as excinfo: + # t.accept_rw(a) + # assert 'incompatible function arguments' in str(excinfo.value) diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp index ac2ea218..f9500c11 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -188,6 +189,12 @@ NB_MODULE(test_ndarray_ext, m) { m.def("check_device", [](nb::ndarray) -> const char * { return "cpu"; }); m.def("check_device", [](nb::ndarray) -> const char * { return "cuda"; }); + m.def("initialize", + [](nb::ndarray, nb::device::cpu> &t) { + for (size_t i = 0; i < 10; ++i) + t(i) = (unsigned char) i; + }); + m.def("initialize", [](nb::ndarray, nb::device::cpu> &t) { for (size_t i = 0; i < 10; ++i) @@ -313,6 +320,19 @@ NB_MODULE(test_ndarray_ext, m) { deleter); }); + m.def("ret_array_api", []() { + double *d = new double[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; + size_t shape[2] = { 2, 4 }; + + nb::capsule deleter(d, [](void *data) noexcept { + destruct_count++; + delete[] (double *) data; + }); + + return nb::ndarray>(d, 2, shape, + deleter); + }); + m.def("ret_array_scalar", []() { float* f = new float{ 1.0f }; @@ -351,7 +371,7 @@ NB_MODULE(test_ndarray_ext, m) { destruct_count++; } - float data [10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + float data[10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; }; nb::class_(m, "Cls") @@ -507,4 +527,34 @@ NB_MODULE(test_ndarray_ext, m) { nb::class_(m, "Wrapper", nb::type_slots(wrapper_slots)) .def(nb::init>()) .def_rw("value", &Wrapper::value); + + // Example from docs/ndarray.rst in section "Array libraries" + class MyArray { + double* d; + public: + MyArray() { d = new double[5] { 0.0, 1.0, 2.0, 3.0, 4.0 }; } + ~MyArray() { delete[] d; } + double* data() const { return d; } + void mutate() { for (int i = 
0; i < 5; ++i) d[i] += 0.5; } + }; + + nb::class_(m, "MyArray") + .def(nb::init<>()) + .def("mutate", &MyArray::mutate) + .def("__dlpack__", [](nb::pointer_and_handle self, + nb::kwargs kwargs) { + using array_api_t = nb::ndarray; + nb::object aa = nb::cast(array_api_t(self.p->data(), {5}), + nb::rv_policy::reference_internal, + self.h); + nb::object max = kwargs.get("max_version", nb::none()); + return aa.attr("__dlpack__")(nb::arg("max_version") = max); + }) + .def("__dlpack_device__", [](nb::handle /*self*/) { + return std::make_pair(nb::device::cpu::value, 0); + }) + .def("array_api", [](const MyArray& self) { + return nb::ndarray(self.data(), {5}); + }, nb::rv_policy::reference_internal); + } diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 73567d1c..72ee7b93 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -133,8 +133,21 @@ def test04_constrain_shape(): t.pass_float32_shaped(np.zeros((3, 5, 4, 6), dtype=np.float32)) +def test05_bytes(): + a = bytearray(range(10)) + assert t.get_is_valid(a) + assert t.get_shape(a) == [10] + assert t.get_size(a) == 10 + assert t.get_nbytes(a) == 10 + assert t.get_itemsize(a) == 1 + assert t.check_order(a) == 'C' + b = b'hello' # immutable + assert t.get_is_valid(b) + assert t.get_shape(b) == [5] + + @needs_numpy -def test05_constrain_order(): +def test06_constrain_order_numpy(): assert t.check_order(np.zeros((3, 5, 4, 6), order='C')) == 'C' assert t.check_order(np.zeros((3, 5, 4, 6), order='F')) == 'F' assert t.check_order(np.zeros((3, 5, 4, 6), order='C')[:, 2, :, :]) == '?' 
@@ -160,8 +173,18 @@ def test07_constrain_order_pytorch(): assert t.check_device(torch.zeros(3, 5, device='cuda')) == 'cuda' +def test08_write_bytes_from_cpp(): + a = bytearray(10) + t.initialize(a) + assert a == bytearray(range(10)) + b = b'helloHello' # ten immutable bytes + with pytest.raises(TypeError) as excinfo: + t.initialize(b) + assert 'incompatible function arguments' in str(excinfo.value) + + @needs_numpy -def test09_write_from_cpp(): +def test09_write_numpy_from_cpp(): x = np.zeros(10, dtype=np.float32) t.initialize(x) assert np.all(x == np.arange(10, dtype=np.float32)) @@ -244,7 +267,6 @@ def __dlpack__(self): x = np.from_dlpack(wrapper(capsule)) else: pytest.skip('your version of numpy is too old') - del capsule collect() assert x.shape == (2, 4) @@ -297,6 +319,7 @@ def test17_return_numpy(): dc = t.destruct_count() x = t.ret_numpy() assert x.shape == (2, 4) + assert x.flags.writeable assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) del x collect() @@ -335,6 +358,60 @@ def test19_return_memview(): assert t.destruct_count() - dc == 1 +@needs_numpy +def test20_return_array_api(): + collect() + dc = t.destruct_count() + obj = t.ret_array_api() + assert obj.__dlpack_device__() == (1, 0) # (type == CPU, id == 0) + capsule = obj.__dlpack__() + assert 'dltensor' in repr(capsule) + assert 'versioned' not in repr(capsule) + capsule = obj.__dlpack__(max_version=None) + assert 'dltensor' in repr(capsule) + assert 'versioned' not in repr(capsule) + capsule = obj.__dlpack__(max_version=(0, 0)) # (major == 0, minor == 0) + assert 'dltensor' in repr(capsule) + assert 'versioned' not in repr(capsule) + capsule = obj.__dlpack__(max_version=(1, 0)) # (major == 1, minor == 0) + assert 'dltensor_versioned' in repr(capsule) + with pytest.raises(TypeError) as excinfo: + capsule = obj.__dlpack__(0) + assert 'does not accept positional arguments' in str(excinfo.value) + del obj + collect() + assert t.destruct_count() == dc + del capsule + collect() + assert 
t.destruct_count() - dc == 1 + dc += 1 + + obj = t.ret_array_api() # obj also supports the buffer protocol + mv = memoryview(obj) + assert mv.tolist() == [[1, 2, 3, 4], [5, 6, 7, 8]] + del obj + collect() + assert t.destruct_count() == dc + del mv + collect() + assert t.destruct_count() - dc == 1 + dc += 1 + + if (hasattr(np, '__array_api_version__') and + np.__array_api_version__ >= '2024'): + obj = t.ret_array_api() + x = np.from_dlpack(obj) + del obj + collect() + assert t.destruct_count() == dc + assert x.shape == (2, 4) + assert x.flags.writeable + assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) + del x + collect() + assert t.destruct_count() - dc == 1 + + @needs_numpy def test21_return_array_scalar(): collect() @@ -418,6 +495,8 @@ def test26_return_ro(): assert t.ret_numpy_const_ref_f.__doc__ == 'ret_numpy_const_ref_f() -> numpy.ndarray[dtype=float32, shape=(2, 4), order=\'F\', writable=False]' assert x.shape == (2, 4) assert y.shape == (2, 4) + assert not x.flags.writeable + assert not y.flags.writeable assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) assert np.all(y == [[1, 3, 5, 7], [2, 4, 6, 8]]) with pytest.raises(ValueError) as excinfo: @@ -428,13 +507,49 @@ def test26_return_ro(): assert 'read-only' in str(excinfo.value) +def test27_python_array(): + import array + a = array.array('d', [0, 0, 0, 3.14159, 0]) + assert t.check(a) + assert t.check_rw_by_value(a) + assert a[1] == 1.414214 + assert t.check_rw_by_value_float64(a) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_value_ro(a) + assert t.check_ro_by_value_const_float64(a) + + a[1] = 0.1 + a[2] = 0.2 + a[4] = 0.4 + mv = memoryview(a) + assert t.check(mv) + assert t.check_rw_by_value(mv) + assert a[1] == 1.414214 + assert t.check_rw_by_value_float64(mv) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_value_ro(mv) + assert t.check_ro_by_value_const_float64(mv) + + x = t.passthrough(a) + assert x is a + + +def test28_check_bytearray(): + a = 
bytearray(b'xyz') + assert t.check(a) + mv = memoryview(a) + assert t.check(mv) + + @needs_numpy -def test27_check_numpy(): +def test29_check_numpy(): assert t.check(np.zeros(1)) @needs_torch -def test28_check_torch(): +def test30_check_torch(): assert t.check(torch.zeros((1))) @@ -597,7 +712,7 @@ def test35_view(): x2 = x1 * (-1+2j) t.fill_view_5(x1) assert np.allclose(x1, x2) - x2 = -x2; + x2 = -x2 t.fill_view_6(x1) assert np.allclose(x1, x2) @@ -659,109 +774,109 @@ def test41_noninteger_stride(): a = np.array([[1, 2, 3, 4, 0, 0], [5, 6, 7, 8, 0, 0]], dtype=np.float32) s = a[:, 0:4] # slice t.pass_float32(s) - assert t.get_stride(s, 0) == 6; - assert t.get_stride(s, 1) == 1; + assert t.get_stride(s, 0) == 6 + assert t.get_stride(s, 1) == 1 try: v = s.view(np.complex64) except: pytest.skip('your version of numpy is too old') t.pass_complex64(v) - assert t.get_stride(v, 0) == 3; - assert t.get_stride(v, 1) == 1; + assert t.get_stride(v, 0) == 3 + assert t.get_stride(v, 1) == 1 a = np.array([[1, 2, 3, 4, 0], [5, 6, 7, 8, 0]], dtype=np.float32) s = a[:, 0:4] # slice t.pass_float32(s) - assert t.get_stride(s, 0) == 5; - assert t.get_stride(s, 1) == 1; + assert t.get_stride(s, 0) == 5 + assert t.get_stride(s, 1) == 1 v = s.view(np.complex64) with pytest.raises(TypeError) as excinfo: t.pass_complex64(v) assert 'incompatible function arguments' in str(excinfo.value) with pytest.raises(TypeError) as excinfo: - t.get_stride(v, 0); + t.get_stride(v, 0) assert 'incompatible function arguments' in str(excinfo.value) @needs_numpy def test42_const_qualifiers_numpy(): a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64) - assert t.check_rw_by_value(a); - assert a[1] == 1.414214; - assert t.check_rw_by_value_float64(a); - assert a[2] == 2.718282; - assert a[4] == 16.0; - assert t.check_ro_by_value_ro(a); - assert t.check_ro_by_value_const_float64(a); + assert t.check_rw_by_value(a) + assert a[1] == 1.414214 + assert t.check_rw_by_value_float64(a) + assert a[2] == 2.718282 + 
assert a[4] == 16.0 + assert t.check_ro_by_value_ro(a) + assert t.check_ro_by_value_const_float64(a) a.setflags(write=False) - assert t.check_ro_by_value_ro(a); - assert t.check_ro_by_value_const_float64(a); - assert a[0] == 0.0; - assert a[3] == 3.14159; + assert t.check_ro_by_value_ro(a) + assert t.check_ro_by_value_const_float64(a) + assert a[0] == 0.0 + assert a[3] == 3.14159 a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64) - assert t.check_rw_by_const_ref(a); - assert a[1] == 1.414214; - assert t.check_rw_by_const_ref_float64(a); - assert a[2] == 2.718282; - assert a[4] == 16.0; - assert t.check_ro_by_const_ref_ro(a); - assert t.check_ro_by_const_ref_const_float64(a); + assert t.check_rw_by_const_ref(a) + assert a[1] == 1.414214 + assert t.check_rw_by_const_ref_float64(a) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_const_ref_ro(a) + assert t.check_ro_by_const_ref_const_float64(a) a.setflags(write=False) - assert t.check_ro_by_const_ref_ro(a); - assert t.check_ro_by_const_ref_const_float64(a); - assert a[0] == 0.0; - assert a[3] == 3.14159; + assert t.check_ro_by_const_ref_ro(a) + assert t.check_ro_by_const_ref_const_float64(a) + assert a[0] == 0.0 + assert a[3] == 3.14159 a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64) - assert t.check_rw_by_rvalue_ref(a); - assert a[1] == 1.414214; - assert t.check_rw_by_rvalue_ref_float64(a); - assert a[2] == 2.718282; - assert a[4] == 16.0; - assert t.check_ro_by_rvalue_ref_ro(a); - assert t.check_ro_by_rvalue_ref_const_float64(a); + assert t.check_rw_by_rvalue_ref(a) + assert a[1] == 1.414214 + assert t.check_rw_by_rvalue_ref_float64(a) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_rvalue_ref_ro(a) + assert t.check_ro_by_rvalue_ref_const_float64(a) a.setflags(write=False) - assert t.check_ro_by_rvalue_ref_ro(a); - assert t.check_ro_by_rvalue_ref_const_float64(a); - assert a[0] == 0.0; - assert a[3] == 3.14159; + assert t.check_ro_by_rvalue_ref_ro(a) + assert 
t.check_ro_by_rvalue_ref_const_float64(a) + assert a[0] == 0.0 + assert a[3] == 3.14159 @needs_torch def test43_const_qualifiers_pytorch(): a = torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64) - assert t.check_rw_by_value(a); - assert a[1] == 1.414214; - assert t.check_rw_by_value_float64(a); - assert a[2] == 2.718282; - assert a[4] == 16.0; - assert t.check_ro_by_value_ro(a); - assert t.check_ro_by_value_const_float64(a); - assert a[0] == 0.0; - assert a[3] == 3.14159; + assert t.check_rw_by_value(a) + assert a[1] == 1.414214 + assert t.check_rw_by_value_float64(a) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_value_ro(a) + assert t.check_ro_by_value_const_float64(a) + assert a[0] == 0.0 + assert a[3] == 3.14159 a = torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64) - assert t.check_rw_by_const_ref(a); - assert a[1] == 1.414214; - assert t.check_rw_by_const_ref_float64(a); - assert a[2] == 2.718282; - assert a[4] == 16.0; - assert t.check_ro_by_const_ref_ro(a); - assert t.check_ro_by_const_ref_const_float64(a); - assert a[0] == 0.0; - assert a[3] == 3.14159; + assert t.check_rw_by_const_ref(a) + assert a[1] == 1.414214 + assert t.check_rw_by_const_ref_float64(a) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_const_ref_ro(a) + assert t.check_ro_by_const_ref_const_float64(a) + assert a[0] == 0.0 + assert a[3] == 3.14159 a = torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64) - assert t.check_rw_by_rvalue_ref(a); - assert a[1] == 1.414214; - assert t.check_rw_by_rvalue_ref_float64(a); - assert a[2] == 2.718282; - assert a[4] == 16.0; - assert t.check_ro_by_rvalue_ref_ro(a); - assert t.check_ro_by_rvalue_ref_const_float64(a); - assert a[0] == 0.0; - assert a[3] == 3.14159; + assert t.check_rw_by_rvalue_ref(a) + assert a[1] == 1.414214 + assert t.check_rw_by_rvalue_ref_float64(a) + assert a[2] == 2.718282 + assert a[4] == 16.0 + assert t.check_ro_by_rvalue_ref_ro(a) + assert 
t.check_ro_by_rvalue_ref_const_float64(a) + assert a[0] == 0.0 + assert a[3] == 3.14159 @needs_cupy @@ -896,3 +1011,26 @@ def test52_accept_np_both_true_contig(): def test53_issue_930(): wrapper = t.Wrapper(np.ones(3, dtype=np.float32)) assert wrapper.value[0] == 1 + + +@needs_numpy +def test54_docs_example(): + ma = t.MyArray() + aa = ma.array_api() + assert 'versioned' not in repr(aa.__dlpack__()) + assert 'versioned' not in repr(ma.__dlpack__()) + assert 'versioned' in repr(aa.__dlpack__(max_version=(1, 2))) + assert 'versioned' in repr(ma.__dlpack__(max_version=(1, 2))) + assert aa.__dlpack_device__() == (1, 0) + assert ma.__dlpack_device__() == (1, 0) + + if hasattr(np, 'from_dlpack'): + x = np.from_dlpack(aa) + y = np.from_dlpack(ma) + assert np.all(x == [0.0, 1.0, 2.0, 3.0, 4.0]) + assert np.all(y == [0.0, 1.0, 2.0, 3.0, 4.0]) + ma.mutate() + assert np.all(x == [0.5, 1.5, 2.5, 3.5, 4.5]) + assert np.all(y == [0.5, 1.5, 2.5, 3.5, 4.5]) + else: + pytest.skip('your version of numpy is too old') diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref index 58e9d974..6975da39 100644 --- a/tests/test_ndarray_ext.pyi.ref +++ b/tests/test_ndarray_ext.pyi.ref @@ -81,6 +81,9 @@ def check_device(arg: Annotated[NDArray, dict(device='cpu')], /) -> str: ... @overload def check_device(arg: Annotated[NDArray, dict(device='cuda')], /) -> str: ... +@overload +def initialize(arg: Annotated[NDArray[numpy.uint8], dict(shape=(10), device='cpu')], /) -> None: ... + @overload def initialize(arg: Annotated[NDArray[numpy.float32], dict(shape=(10), device='cpu')], /) -> None: ... @@ -117,6 +120,8 @@ def ret_pytorch() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ... def ret_memview() -> memoryview[dtype=float64, shape=(2, 4)]: ... +def ret_array_api() -> ArrayLike[dtype=float64, shape=(2, 4)]: ... + def ret_array_scalar() -> NDArray[numpy.float32]: ... 
def noop_3d_c_contig(arg: Annotated[NDArray[numpy.float32], dict(shape=(None, None, None), order='C')], /) -> None: ... @@ -197,3 +202,14 @@ class Wrapper: @value.setter def value(self, arg: NDArray[numpy.float32], /) -> None: ... + +class MyArray: + def __init__(self) -> None: ... + + def mutate(self) -> None: ... + + def __dlpack__(self, **kwargs) -> object: ... + + def __dlpack_device__(self) -> tuple[int, int]: ... + + def array_api(self) -> ArrayLike[dtype=float64]: ... diff --git a/tests/test_tensorflow.cpp b/tests/test_tensorflow.cpp index 37088e13..a6c74e52 100644 --- a/tests/test_tensorflow.cpp +++ b/tests/test_tensorflow.cpp @@ -1,10 +1,10 @@ #include #include + namespace nb = nanobind; int destruct_count = 0; - NB_MODULE(test_tensorflow_ext, m) { m.def("destruct_count", []() { return destruct_count; }); m.def("ret_tensorflow", []() { @@ -14,12 +14,14 @@ NB_MODULE(test_tensorflow_ext, m) { Buf *buf = new Buf({ 1, 2, 3, 4, 5, 6, 7, 8 }); size_t shape[2] = { 2, 4 }; - nb::capsule deleter(buf, [](void *data) noexcept { + nb::capsule deleter(buf, [](void *p) noexcept { destruct_count++; - delete (Buf *) data; + delete (Buf *) p; }); - return nb::ndarray>(buf->f, 2, shape, + return nb::ndarray>(buf->f, + 2, + shape, deleter); }); } diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py index 5d8c3525..9092344b 100644 --- a/tests/test_tensorflow.py +++ b/tests/test_tensorflow.py @@ -15,7 +15,7 @@ def needs_tensorflow(x): @needs_tensorflow -def test01_constrain_order_tensorflow(): +def test01_constrain_order(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -27,7 +27,7 @@ def test01_constrain_order_tensorflow(): @needs_tensorflow -def test02_implicit_conversion_tensorflow(): +def test02_implicit_conversion(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -61,5 +61,37 @@ def test03_return_tensorflow(): @needs_tensorflow -def test04_check_tensorflow(): +def test04_check(): assert t.check(tf.zeros((1))) + 
+ +@needs_tensorflow +def test05_passthrough(): + a = ttf.ret_tensorflow() + b = t.passthrough(a) + assert a is b + + a = tf.constant([1, 2, 3]) + b = t.passthrough(a) + assert a is b + + a = None + with pytest.raises(TypeError) as excinfo: + b = t.passthrough(a) + assert 'incompatible function arguments' in str(excinfo.value) + b = t.passthrough_arg_none(a) + assert a is b + + +@needs_tensorflow +def test06_ro_array(): + if tf.__version__ < '2.19': + pytest.skip('tensorflow version is too old') + a = tf.constant([1, 2], dtype=tf.float32) # immutable + assert t.accept_ro(a) == 1 + # If the next line fails, delete it, update the version above, + # and uncomment the three lines below. + assert t.accept_rw(a) == 1 + # with pytest.raises(TypeError) as excinfo: + # t.accept_rw(a) + # assert 'incompatible function arguments' in str(excinfo.value)