Skip to content

Commit c29edd7

Browse files
committed
unittest: test where subclass modifies input C++ matrix
* test fails, problem in stride ? only first column modified
1 parent cdd9c6a commit c29edd7

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

unittest/eigen_ref.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,27 @@ const Eigen::Ref<const MatType> asConstRef(Eigen::Ref<MatType> mat) {
5252
return Eigen::Ref<const MatType>(mat);
5353
}
5454

55+
struct modify_block
56+
{
57+
MatrixXd J;
58+
modify_block() : J(10, 10) { J.setZero(); }
59+
void modify(int n, int m)
60+
{
61+
call(J.topLeftCorner(n, m));
62+
}
63+
virtual void call(Eigen::Ref<MatrixXd> mat) = 0;
64+
};
65+
66+
struct modify_wrap : modify_block, bp::wrapper<modify_block>
67+
{
68+
modify_wrap() : modify_block() {}
69+
void call(Eigen::Ref<MatrixXd> mat)
70+
{
71+
this->get_override("call")(mat);
72+
}
73+
};
74+
75+
5576
BOOST_PYTHON_MODULE(eigen_ref) {
5677
namespace bp = boost::python;
5778
eigenpy::enableEigenPy();
@@ -77,4 +98,11 @@ BOOST_PYTHON_MODULE(eigen_ref) {
7798
(Eigen::Ref<MatrixXd>(*)(Eigen::Ref<MatrixXd>))asRef<MatrixXd>);
7899
bp::def("asConstRef", (const Eigen::Ref<const MatrixXd> (*)(
79100
Eigen::Ref<MatrixXd>))asConstRef<MatrixXd>);
101+
102+
bp::class_<modify_wrap, boost::noncopyable>("modify_block", bp::init<>())
103+
.def_readonly("J", &modify_block::J)
104+
.def("modify", &modify_block::modify)
105+
.def("call", bp::pure_virtual(&modify_wrap::call))
106+
;
107+
80108
}

unittest/python/test_eigen_ref.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,23 @@ def test(mat):
2424
const_ref = asConstRef(mat)
2525
assert np.all(const_ref == mat)
2626

27+
class ModifyBlockImpl(modify_block):
28+
def __init__(self):
29+
super().__init__()
30+
31+
def call(self, mat):
32+
mat[:, :] = 1.
33+
34+
modify = ModifyBlockImpl()
35+
print("Field J init:\n{}".format(modify.J))
36+
modify.modify(2, 3)
37+
print("Field J after:\n{}".format(modify.J))
38+
Jref = np.zeros((10, 10))
39+
Jref[:2, :3] = 1.
40+
print("Should be:\n{}".format(Jref))
41+
42+
assert np.array_equal(Jref, modify.J)
43+
2744

2845
rows = 10
2946
cols = 30

0 commit comments

Comments
 (0)