Skip to content

Commit b675e9e

Browse files
committed
ufunc: use Eigen algebra to compute dotfunc
1 parent 06e4150 commit b675e9e

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

include/eigenpy/user-type.hpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,18 +199,15 @@ namespace eigenpy
199199
void * op, npy_intp n, void * /*arr*/)
200200
{
201201
// std::cout << "dotfunc" << std::endl;
202-
T res(0);
203-
char *ip0 = (char*)ip0_, *ip1 = (char*)ip1_;
204-
npy_intp i;
205-
for(i = 0; i < n; i++)
206-
{
207-
208-
res += *static_cast<T*>(static_cast<void*>(ip0))
209-
* *static_cast<T*>(static_cast<void*>(ip1));
210-
ip0 += is0;
211-
ip1 += is1;
212-
}
213-
*static_cast<T*>(op) = res;
202+
typedef Eigen::Matrix<T,Eigen::Dynamic,1> VectorT;
203+
typedef Eigen::InnerStride<Eigen::Dynamic> InputStride;
204+
typedef const Eigen::Map<const VectorT,0,InputStride> ConstMapType;
205+
206+
ConstMapType
207+
v0(static_cast<T*>(ip0_),n,InputStride(is0/sizeof(T))),
208+
v1(static_cast<T*>(ip1_),n,InputStride(is1/sizeof(T)));
209+
210+
*static_cast<T*>(op) = v0.dot(v1);
214211
}
215212

216213

unittest/python/test_user_type.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
def test(dtype):
99
mat = np.ones((rows,cols),dtype=dtype)
10+
mat = np.random.rand(rows,cols).astype(dtype)
1011
mat_copy = mat.copy()
1112
assert (mat == mat_copy).all()
1213
assert not (mat != mat_copy).all()
@@ -33,8 +34,11 @@ def test(dtype):
3334
assert not (mat < mat).all()
3435

3536
mat2 = mat.dot(mat.T)
37+
mat2_ref = mat.astype(np.double).dot(mat.T.astype(np.double))
38+
assert np.isclose(mat2.astype(np.double),mat2_ref).all()
3639
if np.__version__ >= '1.17.0':
3740
mat2 = np.matmul(mat,mat.T)
41+
assert np.isclose(mat2.astype(np.double),mat2_ref).all()
3842

3943
def test_cast(from_dtype,to_dtype):
4044
np.can_cast(from_dtype,to_dtype)

unittest/user_type.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ BOOST_PYTHON_MODULE(user_type)
200200
eigenpy::registerCast<int32_t,DoubleType>(true);
201201
eigenpy::registerCast<DoubleType,int64_t>(false);
202202
eigenpy::registerCast<int64_t,DoubleType>(true);
203+
eigenpy::registerCast<FloatType,double>(true);
204+
eigenpy::registerCast<double,FloatType>(false);
203205
eigenpy::registerCast<FloatType,int64_t>(false);
204206
eigenpy::registerCast<int64_t,FloatType>(true);
205207

0 commit comments

Comments
 (0)