Skip to content

Commit bbb129d

Browse files
committed
test: add more test for Eigen::Ref
1 parent cd36103 commit bbb129d

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

unittest/eigen_ref.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ void setOnes(Eigen::Ref<MatType> mat) {
3030
mat.setOnes();
3131
}
3232

33+
template <typename MatType>
34+
Eigen::Ref<MatType> getBlock(Eigen::Ref<MatType> mat, Eigen::DenseIndex i,
35+
Eigen::DenseIndex j, Eigen::DenseIndex n,
36+
Eigen::DenseIndex m) {
37+
return mat.block(i, j, n, m);
38+
}
39+
3340
template <typename MatType>
3441
void fill(Eigen::Ref<MatType> mat, const typename MatType::Scalar& value) {
3542
mat.fill(value);
@@ -90,6 +97,8 @@ BOOST_PYTHON_MODULE(eigen_ref) {
9097
bp::def("asConstRef", (const Eigen::Ref<const MatrixXd> (*)(
9198
Eigen::Ref<MatrixXd>))asConstRef<MatrixXd>);
9299

100+
bp::def("getBlock", &getBlock<MatrixXd>);
101+
93102
bp::class_<modify_wrap, boost::noncopyable>("modify_block", bp::init<>())
94103
.def_readonly("J", &modify_block::J)
95104
.def("modify", &modify_block::modify)

unittest/python/test_eigen_ref.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,25 @@ def test(mat):
2222
assert np.all(ref == mat)
2323

2424
const_ref = asConstRef(mat)
25+
# import pdb; pdb.set_trace()
2526
assert np.all(const_ref == mat)
2627

28+
mat.fill(0.0)
29+
fill(mat[:3, :2], 1.0)
30+
31+
assert np.all(mat[:3, :2] == np.ones((3, 2)))
32+
33+
mat.fill(0.0)
34+
fill(mat[:2, :3], 1.0)
35+
36+
assert np.all(mat[:2, :3] == np.ones((2, 3)))
37+
38+
mat.fill(0.0)
39+
mat_as_C_order = np.array(mat, order="F")
40+
getBlock(mat_as_C_order, 0, 0, 3, 2)[:, :] = 1.0
41+
42+
assert np.all(mat_as_C_order[:3, :2] == np.ones((3, 2)))
43+
2744
class ModifyBlockImpl(modify_block):
2845
def __init__(self):
2946
super().__init__()
@@ -32,12 +49,9 @@ def call(self, mat):
3249
mat[:, :] = 1.0
3350

3451
modify = ModifyBlockImpl()
35-
print("Field J init:\n{}".format(modify.J))
3652
modify.modify(2, 3)
37-
print("Field J after:\n{}".format(modify.J))
3853
Jref = np.zeros((10, 10))
3954
Jref[:2, :3] = 1.0
40-
print("Should be:\n{}".format(Jref))
4155

4256
assert np.array_equal(Jref, modify.J)
4357

0 commit comments

Comments
 (0)