Skip to content

Commit ab09248

Browse files
authored
Merge pull request #279 from ManifoldFR/modify-block-test
Test: modify a matrix block through Python subclass
2 parents 05341c9 + 6feaa06 commit ab09248

File tree

5 files changed

+116
-4
lines changed

5 files changed

+116
-4
lines changed

include/eigenpy/numpy-allocator.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,15 @@ struct NumpyAllocator<Eigen::Ref<MatType, Options, Stride> > {
7070

7171
if (NumpyType::sharedMemory()) {
7272
const int Scalar_type_code = Register::getTypeCode<Scalar>();
73+
Eigen::DenseIndex inner_stride = mat.innerStride(),
74+
outer_stride = mat.outerStride();
75+
76+
const int elsize = call_PyArray_DescrFromType(Scalar_type_code)->elsize;
77+
npy_intp strides[2] = {elsize * inner_stride, elsize * outer_stride};
78+
7379
PyArrayObject *pyArray = (PyArrayObject *)call_PyArray_New(
7480
getPyArrayType(), static_cast<int>(nd), shape, Scalar_type_code,
75-
mat.data(), NPY_ARRAY_MEMORY_CONTIGUOUS | NPY_ARRAY_ALIGNED);
81+
strides, mat.data(), NPY_ARRAY_MEMORY_CONTIGUOUS | NPY_ARRAY_ALIGNED);
7682

7783
return pyArray;
7884
} else {
@@ -125,9 +131,15 @@ struct NumpyAllocator<const Eigen::Ref<const MatType, Options, Stride> > {
125131

126132
if (NumpyType::sharedMemory()) {
127133
const int Scalar_type_code = Register::getTypeCode<Scalar>();
134+
Eigen::DenseIndex inner_stride = mat.innerStride(),
135+
outer_stride = mat.outerStride();
136+
137+
const int elsize = call_PyArray_DescrFromType(Scalar_type_code)->elsize;
138+
npy_intp strides[2] = {elsize * inner_stride, elsize * outer_stride};
139+
128140
PyArrayObject *pyArray = (PyArrayObject *)call_PyArray_New(
129141
getPyArrayType(), static_cast<int>(nd), shape, Scalar_type_code,
130-
const_cast<Scalar *>(mat.data()),
142+
strides, const_cast<Scalar *>(mat.data()),
131143
NPY_ARRAY_MEMORY_CONTIGUOUS_RO | NPY_ARRAY_ALIGNED);
132144

133145
return pyArray;

include/eigenpy/numpy.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2021 INRIA
2+
* Copyright 2020-2022 INRIA
33
*/
44

55
#ifndef __eigenpy_numpy_hpp__
@@ -109,6 +109,11 @@ EIGENPY_DLLAPI PyObject* call_PyArray_New(PyTypeObject* py_type_ptr, int nd,
109109
npy_intp* shape, int np_type,
110110
void* data_ptr, int options);
111111

112+
EIGENPY_DLLAPI PyObject* call_PyArray_New(PyTypeObject* py_type_ptr, int nd,
113+
npy_intp* shape, int np_type,
114+
npy_intp* strides, void* data_ptr,
115+
int options);
116+
112117
EIGENPY_DLLAPI int call_PyArray_ObjectType(PyObject*, int);
113118

114119
EIGENPY_DLLAPI PyTypeObject* getPyArrayType();
@@ -143,6 +148,14 @@ inline PyObject* call_PyArray_New(PyTypeObject* py_type_ptr, int nd,
143148
options, NULL);
144149
}
145150

151+
inline PyObject* call_PyArray_New(PyTypeObject* py_type_ptr, int nd,
152+
npy_intp* shape, int np_type,
153+
npy_intp* strides, void* data_ptr,
154+
int options) {
155+
return PyArray_New(py_type_ptr, nd, shape, np_type, strides, data_ptr, 0,
156+
options, NULL);
157+
}
158+
146159
inline int call_PyArray_ObjectType(PyObject* obj, int val) {
147160
return PyArray_ObjectType(obj, val);
148161
}

src/numpy.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2021 INRIA
2+
* Copyright 2020-2022 INRIA
33
*/
44

55
#include "eigenpy/numpy.hpp"
@@ -31,6 +31,13 @@ PyObject* call_PyArray_New(PyTypeObject* py_type_ptr, int nd, npy_intp* shape,
3131
options, NULL);
3232
}
3333

34+
PyObject* call_PyArray_New(PyTypeObject* py_type_ptr, int nd, npy_intp* shape,
35+
int np_type, npy_intp* strides, void* data_ptr,
36+
int options) {
37+
return PyArray_New(py_type_ptr, nd, shape, np_type, strides, data_ptr, 0,
38+
options, NULL);
39+
}
40+
3441
int call_PyArray_ObjectType(PyObject* obj, int val) {
3542
return PyArray_ObjectType(obj, val);
3643
}

unittest/eigen_ref.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,28 @@ void setOnes(Eigen::Ref<MatType> mat) {
3030
mat.setOnes();
3131
}
3232

33+
template <typename MatType>
34+
Eigen::Ref<MatType> getBlock(Eigen::Ref<MatType> mat, Eigen::DenseIndex i,
35+
Eigen::DenseIndex j, Eigen::DenseIndex n,
36+
Eigen::DenseIndex m) {
37+
return mat.block(i, j, n, m);
38+
}
39+
40+
template <typename MatType>
41+
Eigen::Ref<MatType> editBlock(Eigen::Ref<MatType> mat, Eigen::DenseIndex i,
42+
Eigen::DenseIndex j, Eigen::DenseIndex n,
43+
Eigen::DenseIndex m) {
44+
typename Eigen::Ref<MatType>::BlockXpr B = mat.block(i, j, n, m);
45+
int k = 0;
46+
for (int i = 0; i < B.rows(); ++i) {
47+
for (int j = 0; j < B.cols(); ++j) {
48+
B(i, j) = k++;
49+
}
50+
}
51+
std::cout << "B:\n" << B << std::endl;
52+
return mat;
53+
}
54+
3355
template <typename MatType>
3456
void fill(Eigen::Ref<MatType> mat, const typename MatType::Scalar& value) {
3557
mat.fill(value);
@@ -52,6 +74,18 @@ const Eigen::Ref<const MatType> asConstRef(Eigen::Ref<MatType> mat) {
5274
return Eigen::Ref<const MatType>(mat);
5375
}
5476

77+
struct modify_block {
78+
MatrixXd J;
79+
modify_block() : J(10, 10) { J.setZero(); }
80+
void modify(int n, int m) { call(J.topLeftCorner(n, m)); }
81+
virtual void call(Eigen::Ref<MatrixXd> mat) = 0;
82+
};
83+
84+
struct modify_wrap : modify_block, bp::wrapper<modify_block> {
85+
modify_wrap() : modify_block() {}
86+
void call(Eigen::Ref<MatrixXd> mat) { this->get_override("call")(mat); }
87+
};
88+
5589
BOOST_PYTHON_MODULE(eigen_ref) {
5690
namespace bp = boost::python;
5791
eigenpy::enableEigenPy();
@@ -77,4 +111,12 @@ BOOST_PYTHON_MODULE(eigen_ref) {
77111
(Eigen::Ref<MatrixXd>(*)(Eigen::Ref<MatrixXd>))asRef<MatrixXd>);
78112
bp::def("asConstRef", (const Eigen::Ref<const MatrixXd> (*)(
79113
Eigen::Ref<MatrixXd>))asConstRef<MatrixXd>);
114+
115+
bp::def("getBlock", &getBlock<MatrixXd>);
116+
bp::def("editBlock", &editBlock<MatrixXd>);
117+
118+
bp::class_<modify_wrap, boost::noncopyable>("modify_block", bp::init<>())
119+
.def_readonly("J", &modify_block::J)
120+
.def("modify", &modify_block::modify)
121+
.def("call", bp::pure_virtual(&modify_wrap::call));
80122
}

unittest/python/test_eigen_ref.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,44 @@ def test(mat):
2424
const_ref = asConstRef(mat)
2525
assert np.all(const_ref == mat)
2626

27+
mat.fill(0.0)
28+
fill(mat[:3, :2], 1.0)
29+
30+
assert np.all(mat[:3, :2] == np.ones((3, 2)))
31+
32+
mat.fill(0.0)
33+
fill(mat[:2, :3], 1.0)
34+
35+
assert np.all(mat[:2, :3] == np.ones((2, 3)))
36+
37+
mat.fill(0.0)
38+
mat_as_C_order = np.array(mat, order="F")
39+
getBlock(mat_as_C_order, 0, 0, 3, 2)[:, :] = 1.0
40+
41+
assert np.all(mat_as_C_order[:3, :2] == np.ones((3, 2)))
42+
43+
mat_as_C_order[:3, :2] = 0.0
44+
mat_copy = mat_as_C_order.copy()
45+
editBlock(mat_as_C_order, 0, 0, 3, 2)
46+
mat_copy[:3, :2] = np.arange(6).reshape(3, 2)
47+
48+
assert np.all(mat_as_C_order == mat_copy)
49+
50+
class ModifyBlockImpl(modify_block):
51+
def __init__(self):
52+
super().__init__()
53+
54+
def call(self, mat):
55+
n, m = mat.shape
56+
mat[:, :] = np.arange(n * m).reshape(n, m)
57+
58+
modify = ModifyBlockImpl()
59+
modify.modify(2, 3)
60+
Jref = np.zeros((10, 10))
61+
Jref[:2, :3] = np.arange(6).reshape(2, 3)
62+
63+
assert np.array_equal(Jref, modify.J)
64+
2765

2866
rows = 10
2967
cols = 30

0 commit comments

Comments
 (0)