Skip to content

Commit 43022a1

Browse files
committed
Add example showing how to bind virtual classes, passing to overrides by reference, registering shared_ptr, and overriding factories and data classes from Python
1 parent 2dbad07 commit 43022a1

File tree

3 files changed

+233
-1
lines changed

3 files changed

+233
-1
lines changed

unittest/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ if(CMAKE_CXX_STANDARD GREATER 14 AND CMAKE_CXX_STANDARD LESS 98)
6262
config_bind_optional(std "std::optional")
6363
endif()
6464

65+
add_lib_unit_test(bind_virtual_factory)
66+
6567
add_python_unit_test("py-matrix" "unittest/python/test_matrix.py" "unittest")
6668

6769
add_python_unit_test("py-tensor" "unittest/python/test_tensor.py" "unittest")
@@ -118,4 +120,8 @@ set_tests_properties("py-std-vector" PROPERTIES DEPENDS ${PYWRAP})
118120

119121
add_python_unit_test("py-user-struct" "unittest/python/test_user_struct.py"
120122
"python;unittest")
121-
set_tests_properties("py-std-vector" PROPERTIES DEPENDS ${PYWRAP})
123+
set_tests_properties("py-user-struct" PROPERTIES DEPENDS ${PYWRAP})
124+
125+
add_python_unit_test("py-bind-virtual" "unittest/python/test_bind_virtual.py"
126+
"python;unittest")
127+
set_tests_properties("py-bind-virtual" PROPERTIES DEPENDS ${PYWRAP})

unittest/bind_virtual_factory.cpp

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/// Copyright 2023 LAAS-CNRS, INRIA
2+
#include <eigenpy/eigenpy.hpp>
3+
4+
using std::shared_ptr;
5+
namespace bp = boost::python;
6+
7+
// fwd declaration
8+
struct MyVirtualData;
9+
10+
/// A virtual class with two pure virtual functions taking different signatures,
11+
/// and a polymorphic factory function.
12+
struct MyVirtualClass {
13+
MyVirtualClass() {}
14+
virtual ~MyVirtualClass() {}
15+
16+
// polymorphic fn taking arg by shared_ptr
17+
virtual void doSomethingPtr(shared_ptr<MyVirtualData> const &data) const = 0;
18+
// polymorphic fn taking arg by reference
19+
virtual void doSomethingRef(MyVirtualData &data) const = 0;
20+
21+
virtual shared_ptr<MyVirtualData> createData() const {
22+
return std::make_shared<MyVirtualData>(*this);
23+
}
24+
};
25+
26+
struct MyVirtualData {
27+
MyVirtualData(MyVirtualClass const&) {}
28+
virtual ~MyVirtualData() {} // virtual dtor to mark class as polymorphic
29+
};
30+
31+
shared_ptr<MyVirtualData> callDoSomethingPtr(const MyVirtualClass &obj) {
32+
auto d = obj.createData();
33+
printf("Created MyVirtualData with address %p\n", (void *)d.get());
34+
obj.doSomethingPtr(d);
35+
return d;
36+
}
37+
38+
shared_ptr<MyVirtualData> callDoSomethingRef(const MyVirtualClass &obj) {
39+
auto d = obj.createData();
40+
printf("Created MyVirtualData with address %p\n", (void *)d.get());
41+
obj.doSomethingRef(*d);
42+
return d;
43+
}
44+
45+
void throw_virtual_not_implemented_error() {
46+
throw std::runtime_error("Called C++ virtual function.");
47+
}
48+
49+
/// Wrapper classes
50+
struct VirtualClassWrapper : MyVirtualClass, bp::wrapper<MyVirtualClass> {
51+
void doSomethingPtr(shared_ptr<MyVirtualData> const &data) const override {
52+
if (bp::override fo = this->get_override("doSomethingPtr")) {
53+
/// shared_ptr HAS to be passed by value.
54+
/// Boost.Python's argument converter has the wrong behaviour for
55+
/// reference_wrapper<shared_ptr<T>>, so boost::ref(data) does not work.
56+
fo(data);
57+
return;
58+
}
59+
throw_virtual_not_implemented_error();
60+
}
61+
62+
/// The data object is passed by mutable reference to this function,
63+
/// and wrapped in a @c boost::reference_wrapper when passed to the override.
64+
/// Otherwise, Boost.Python's argument converter will convert to Python by
65+
/// value and create a copy.
66+
void doSomethingRef(MyVirtualData &data) const override {
67+
if (bp::override fo = this->get_override("doSomethingRef")) {
68+
fo(boost::ref(data));
69+
return;
70+
}
71+
throw_virtual_not_implemented_error();
72+
}
73+
74+
shared_ptr<MyVirtualData> createData() const override {
75+
if (bp::override fo = this->get_override("createData")) return fo();
76+
return default_createData();
77+
}
78+
79+
shared_ptr<MyVirtualData> default_createData() const {
80+
return MyVirtualClass::createData();
81+
}
82+
};
83+
84+
/// This "trampoline class" does nothing but is ABSOLUTELY required to ensure
85+
/// downcasting works properly with non-smart ptr signatures. Otherwise,
86+
/// there is no handle to the original Python object ( @c PyObject *).
87+
/// Every single polymorphic type exposed to Python should be exposed through such a trampoline.
88+
/// Users can also create their own wrapper classes by taking inspiration from boost::python::wrapper<T>.
89+
struct DataWrapper : MyVirtualData, bp::wrapper<MyVirtualData> {
90+
/// we have to use-declare non-defaulted constructors
91+
/// (see https://en.cppreference.com/w/cpp/language/default_constructor)
92+
/// or define them manually.
93+
using MyVirtualData::MyVirtualData;
94+
};
95+
96+
/// Take and return a const reference
97+
const MyVirtualData &iden_ref(const MyVirtualData &d) {
98+
// try cast to holder
99+
return d;
100+
}
101+
102+
/// Take a shared_ptr (by const reference or value, doesn't matter), return by const reference
103+
const MyVirtualData &iden_shared(const shared_ptr<MyVirtualData> &d) {
104+
// get boost.python's custom deleter
105+
// boost.python hides the handle to the original object in there
106+
// dter being nonzero indicates shared_ptr was wrapped by Boost.Python
107+
auto *dter = std::get_deleter<bp::converter::shared_ptr_deleter>(d);
108+
if (dter != 0) printf("> input shared_ptr has a deleter\n");
109+
return *d;
110+
}
111+
112+
/// Take and return a shared_ptr
113+
shared_ptr<MyVirtualData> copy_shared(const shared_ptr<MyVirtualData> &d) {
114+
auto *dter = std::get_deleter<bp::converter::shared_ptr_deleter>(d);
115+
if (dter != 0) printf("> input shared_ptr has a deleter\n");
116+
return d;
117+
}
118+
119+
BOOST_PYTHON_MODULE(bind_virtual_factory) {
120+
assert(std::is_polymorphic<MyVirtualClass>::value &&
121+
"MyVirtualClass should be polymorphic!");
122+
assert(std::is_polymorphic<MyVirtualData>::value &&
123+
"MyVirtualData should be polymorphic!");
124+
125+
bp::class_<VirtualClassWrapper, boost::noncopyable>(
126+
"MyVirtualClass", bp::init<>(bp::args("self")))
127+
.def("doSomething", bp::pure_virtual(&MyVirtualClass::doSomethingPtr),
128+
bp::args("self", "data"))
129+
.def("doSomethingRef", bp::pure_virtual(&MyVirtualClass::doSomethingRef),
130+
bp::args("self", "data"))
131+
.def("createData", &MyVirtualClass::createData,
132+
&VirtualClassWrapper::default_createData, bp::args("self"));
133+
134+
bp::register_ptr_to_python<shared_ptr<MyVirtualData> >();
135+
/// Trampoline used as 1st argument
136+
/// otherwise if passed as "HeldType", we need to define
137+
/// the constructor and call initializer manually.
138+
bp::class_<DataWrapper, boost::noncopyable>("MyVirtualData", bp::no_init)
139+
.def(bp::init<MyVirtualClass const&>(bp::args("self", "model")));
140+
141+
bp::def("callDoSomethingPtr", callDoSomethingPtr, bp::args("obj"));
142+
bp::def("callDoSomethingRef", callDoSomethingRef, bp::args("obj"));
143+
144+
bp::def("iden_ref", iden_ref, bp::return_internal_reference<>());
145+
bp::def("iden_shared", iden_shared, bp::return_internal_reference<>());
146+
bp::def("copy_shared", copy_shared);
147+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import bind_virtual_factory as bvf
2+
3+
4+
class ImplClass(bvf.MyVirtualClass):
5+
def __init__(self):
6+
self.val = 42
7+
super().__init__()
8+
9+
def createData(self):
10+
return ImplData(self)
11+
12+
# override MyVirtualClass::doSomethingPtr(shared_ptr data)
13+
def doSomethingPtr(self, data):
14+
print("Hello from doSomething!")
15+
assert isinstance(data, ImplData)
16+
print("Data value:", data.value)
17+
data.value += 1
18+
19+
# override MyVirtualClass::doSomethingPtr(data&)
20+
def doSomethingRef(self, data):
21+
print("Hello from doSomethingRef!")
22+
print(type(data))
23+
assert isinstance(data, ImplData)
24+
print("Data value:", data.value)
25+
data.value += 1
26+
27+
28+
class ImplData(bvf.MyVirtualData):
29+
def __init__(self, c):
30+
super().__init__(c) # parent virtual class requires arg
31+
self.value = c.val
32+
33+
34+
def test_instantiate_child():
35+
obj = ImplClass()
36+
data = obj.createData()
37+
print(data)
38+
39+
40+
def test_call_do_something_ptr():
41+
obj = ImplClass()
42+
print("Calling doSomething (by ptr)")
43+
d1 = bvf.callDoSomethingPtr(obj)
44+
print("Output data.value:", d1.value)
45+
46+
47+
def test_call_do_something_ref():
48+
obj = ImplClass()
49+
print("Ref variant:")
50+
d2 = bvf.callDoSomethingRef(obj)
51+
print(d2.value)
52+
print("-----")
53+
54+
55+
def test_iden_fns():
56+
obj = ImplClass()
57+
d = obj.createData()
58+
print(d, type(d))
59+
60+
# take and return const T&
61+
d1 = bvf.iden_ref(d)
62+
print(d1, type(d1))
63+
assert isinstance(d1, ImplData)
64+
65+
# take a shared_ptr, return const T&
66+
d2 = bvf.iden_shared(d)
67+
assert isinstance(d2, ImplData)
68+
print(d2, type(d2))
69+
70+
print("copy shared ptr -> py -> cpp")
71+
d3 = bvf.copy_shared(d)
72+
assert isinstance(d3, ImplData)
73+
print(d3, type(d3))
74+
75+
76+
test_instantiate_child()
77+
test_call_do_something_ptr()
78+
test_call_do_something_ref()
79+
test_iden_fns()

0 commit comments

Comments
 (0)