Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,44 @@ declarations in generated :ref:`stubs <stubs>`,
Declares a callback that will be invoked when a C++ instance is first
cast into a Python object.

.. cpp:struct:: upcast_hook

.. cpp:function:: upcast_hook(void* (* hook)(PyObject*, const std::type_info*) noexcept)

Allow Python instances of the class being bound to be passed to C++
functions that expect a pointer to a subobject of that class.
Since nanobind only acknowledges at most one base class of each bound type,
the upcast hook can be helpful for providing some minimal emulation of
additional bases.

The hook receives a nanobind instance as its first argument and the
desired subobject type as its second. If it can make the cast, it
returns a pointer to something of the requested type; if not, it
returns nullptr.

**Example:**

.. code-block:: cpp

struct A { int a = 10; };
struct B { int b = 20; };
struct D : A, B { int d = 30; };

nb::class_<A>(m, "A").def_rw("a", &A::a);
auto clsB = nb::class_<B>(m, "B").def_rw("b", &B::b);

auto try_D_to_B = [](PyObject *self_py, const std::type_info *target) noexcept -> void* {
D *self = nb::inst_ptr<D>(self_py);
if (*target == &typeid(B)) {
return static_cast<B*>(self);
}
return nullptr;
};

auto clsD = nb::class_<D, A>(m, "D", nb::upcast_hook(try_D_to_B))
.def_rw("d", &D::d);
clsD.attr("b") = clsB.attr("b");


.. _enum_binding_annotations:

Expand Down
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ Upcoming version (TBA)
long-standing inconvenience. (PR `#778
<https://github.com/wjakob/nanobind/pull/778>`__).

- Added the class binding annotation :cpp:class:`nb::upcast_hook()
<upcast_hook>` which allows the bound type to describe how to
extract self-pointers of other types from its instances. This can
be useful as part of a strategy for mimicking multiple inheritance.
(PR `#920 <https://github.com/wjakob/nanobind/pull/920>`__)

* ABI version 16.


Expand Down
9 changes: 4 additions & 5 deletions include/nanobind/nb_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,10 @@ struct type_slots {
const PyType_Slot *value;
};

struct type_slots_callback {
using cb_t = void (*)(const detail::type_init_data *t,
PyType_Slot *&slots, size_t max_slots) noexcept;
type_slots_callback(cb_t callback) : callback(callback) { }
cb_t callback;
struct upcast_hook {
using cb_t = void* (*)(PyObject *, const std::type_info *) noexcept;
upcast_hook(cb_t hook) : hook(hook) { }
cb_t hook;
};

struct sig {
Expand Down
14 changes: 12 additions & 2 deletions include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ enum class type_flags : uint32_t {

/// Does the type implement a custom __new__ operator that can take no args
/// (except the type object)?
has_nullary_new = (1 << 17)
has_nullary_new = (1 << 17),

// One more bit available without needing a larger reorganization
/// Does the type provide a upcast_hook?
has_upcast_hook = (1 << 18)

// Reorganization will be needed to add any more flags;
// try splitting type_init_flags into a separate field in type_init_data
};

/// Flags about a type that are only relevant when it is being created.
Expand Down Expand Up @@ -125,6 +129,7 @@ struct type_data {
};
void (*set_self_py)(void *, PyObject *) noexcept;
bool (*keep_shared_from_this_alive)(PyObject *) noexcept;
void* (*upcast_hook)(PyObject *, const std::type_info *) noexcept;
#if defined(Py_LIMITED_API)
uint32_t dictoffset;
uint32_t weaklistoffset;
Expand Down Expand Up @@ -183,6 +188,11 @@ NB_INLINE void type_extra_apply(type_init_data & t, const sig &s) {
t.name = s.value;
}

NB_INLINE void type_extra_apply(type_init_data &t, upcast_hook h) {
t.flags |= (uint32_t) type_flags::has_upcast_hook;
t.upcast_hook = h.hook;
}

template <typename T>
NB_INLINE void type_extra_apply(type_init_data &t, supplement<T>) {
static_assert(std::is_trivially_default_constructible_v<T>,
Expand Down
19 changes: 18 additions & 1 deletion src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,8 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
is_generic = t->flags & (uint32_t) type_flags::is_generic,
intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr,
has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this,
has_signature = t->flags & (uint32_t) type_flags::has_signature;
has_signature = t->flags & (uint32_t) type_flags::has_signature,
has_upcast_hook = t->flags & (uint32_t) type_flags::has_upcast_hook;

const char *t_name = t->name;
if (has_signature)
Expand Down Expand Up @@ -1346,6 +1347,12 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
to->keep_shared_from_this_alive = tb->keep_shared_from_this_alive;
}

if (!has_upcast_hook && tb &&
(tb->flags & (uint32_t) type_flags::has_upcast_hook)) {
to->flags |= (uint32_t) type_flags::has_upcast_hook;
to->upcast_hook = tb->upcast_hook;
}

#if defined(Py_LIMITED_API)
to->vectorcall = type_vectorcall;
#else
Expand Down Expand Up @@ -1551,6 +1558,16 @@ bool nb_type_get(const std::type_info *cpp_type, PyObject *src, uint8_t flags,

return true;
}

// This is a nanobind type but not the right one; try an upcast hook
// if one was provided
if (t->flags & (uint32_t) type_flags::has_upcast_hook) {
void *ptr = t->upcast_hook(src, cpp_type);
if (ptr) {
*out = ptr;
return true;
}
}
}

// Try an implicit conversion as last resort (if possible & requested)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,4 +718,26 @@ NB_MODULE(test_classes_ext, m) {
.def_prop_ro_static("x", [](nb::handle /*unused*/) { return 42; });
nb::class_<StaticPropertyOverride2, StaticPropertyOverride>(m, "StaticPropertyOverride2")
.def_prop_ro_static("x", [](nb::handle /*unused*/) { return 43; });

struct MultA { int a = 10; };
struct MultB { int b = 20; };
struct MultD : MultA, MultB { int d = 30; };
struct MultE : MultD { int e = 40; };

nb::class_<MultA>(m, "MultA").def(nb::init<>()).def_rw("a", &MultA::a);
auto clsB = nb::class_<MultB>(m, "MultB").def(nb::init<>()).def_rw("b", &MultB::b);

auto try_D_to_B = [](PyObject *self_py, const std::type_info *target) noexcept -> void* {
MultD *self = nb::inst_ptr<MultD>(self_py);
if (*target == typeid(MultB)) {
return static_cast<MultB*>(self);
}
return nullptr;
};

auto clsD = nb::class_<MultD, MultA>(m, "MultD", nb::upcast_hook(try_D_to_B))
.def(nb::init<>())
.def_rw("d", &MultD::d);
clsD.attr("b") = clsB.attr("b");
nb::class_<MultE, MultD>(m, "MultE").def(nb::init<>()).def_rw("e", &MultE::e);
}
32 changes: 32 additions & 0 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,35 @@ def my_init(self):
def test49_static_property_override():
assert t.StaticPropertyOverride.x == 42
assert t.StaticPropertyOverride2.x == 43

def test50_i_cant_believe_its_not_multiple_inheritance(monkeypatch):
objs = [t.MultB(), t.MultD(), t.MultE()]
for i, obj in enumerate(objs):
assert obj.b == 20
obj.b += i
try:
assert obj.d == 30
obj.d += 100 * i
except AttributeError:
if i != 0:
raise

assert objs[0].b == 20
assert objs[1].b == 21
assert objs[2].b == 22
assert objs[1].d == 130
assert objs[2].d == 230

def patched_instancecheck(cls, inst, *, _orig=type(t.MultB).__instancecheck__):
if _orig(t.MultD, inst) and cls is t.MultB:
return True
return _orig(cls, inst)

monkeypatch.setattr(type(t.MultB), "__instancecheck__", patched_instancecheck)
assert isinstance(objs[0], t.MultB)
assert not isinstance(objs[0], t.MultD)
assert isinstance(objs[1], t.MultB)
assert isinstance(objs[1], t.MultD)
assert isinstance(objs[2], t.MultB)
assert isinstance(objs[2], t.MultD)
assert isinstance(objs[2], t.MultE)
42 changes: 42 additions & 0 deletions tests/test_classes_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,48 @@ class MonkeyPatchable:
@value.setter
def value(self, arg: int, /) -> None: ...

class MultA:
def __init__(self) -> None: ...

@property
def a(self) -> int: ...

@a.setter
def a(self, arg: int, /) -> None: ...

class MultB:
def __init__(self) -> None: ...

@property
def b(self) -> int: ...

@b.setter
def b(self, arg: int, /) -> None: ...

class MultD(MultA):
def __init__(self) -> None: ...

@property
def d(self) -> int: ...

@d.setter
def d(self, arg: int, /) -> None: ...

@property
def b(self) -> int: ...

@b.setter
def b(self, arg: int, /) -> None: ...

class MultE(MultD):
def __init__(self) -> None: ...

@property
def e(self) -> int: ...

@e.setter
def e(self, arg: int, /) -> None: ...

class MyClass:
def __init__(self) -> None: ...

Expand Down