Skip to content

Commit 7ec95da

Browse files
authored
Merge pull request #242 from jcarpent/devel
Enhance support of ufunc
2 parents 5011116 + 23a2cd3 commit 7ec95da

File tree

3 files changed

+116
-14
lines changed

3 files changed

+116
-14
lines changed

cmake

Submodule cmake updated 1 file

include/eigenpy/ufunc.hpp

Lines changed: 111 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
//
2-
// Copyright (c) 2020 INRIA
2+
// Copyright (c) 2020-2021 INRIA
3+
// code aptapted from https://github.com/numpy/numpy/blob/41977b24ae011a51f64faa75cb524c7350fdedd9/numpy/core/src/umath/_rational_tests.c.src
34
//
45

56
#ifndef __eigenpy_ufunc_hpp__
67
#define __eigenpy_ufunc_hpp__
78

89
#include "eigenpy/register.hpp"
10+
#include "eigenpy/user-type.hpp"
911

1012
namespace eigenpy
1113
{
@@ -18,6 +20,77 @@ namespace eigenpy
1820
#define EIGENPY_NPY_CONST_UFUNC_ARG
1921
#endif
2022

23+
template<typename T>
24+
void matrix_multiply(char **args, npy_intp const *dimensions, npy_intp const *steps)
25+
{
26+
/* pointers to data for input and output arrays */
27+
char *ip1 = args[0];
28+
char *ip2 = args[1];
29+
char *op = args[2];
30+
31+
/* lengths of core dimensions */
32+
npy_intp dm = dimensions[0];
33+
npy_intp dn = dimensions[1];
34+
npy_intp dp = dimensions[2];
35+
36+
/* striding over core dimensions */
37+
npy_intp is1_m = steps[0];
38+
npy_intp is1_n = steps[1];
39+
npy_intp is2_n = steps[2];
40+
npy_intp is2_p = steps[3];
41+
npy_intp os_m = steps[4];
42+
npy_intp os_p = steps[5];
43+
44+
/* core dimensions counters */
45+
npy_intp m, p;
46+
47+
/* calculate dot product for each row/column vector pair */
48+
for (m = 0; m < dm; m++)
49+
{
50+
for (p = 0; p < dp; p++)
51+
{
52+
SpecialMethods<T>::dotfunc(ip1, is1_n, ip2, is2_n, op, dn, NULL);
53+
54+
/* advance to next column of 2nd input array and output array */
55+
ip2 += is2_p;
56+
op += os_p;
57+
}
58+
59+
/* reset to first column of 2nd input array and output array */
60+
ip2 -= is2_p * p;
61+
op -= os_p * p;
62+
63+
/* advance to next row of 1st input array and output array */
64+
ip1 += is1_m;
65+
op += os_m;
66+
}
67+
}
68+
69+
template<typename T>
70+
void gufunc_matrix_multiply(char **args, npy_intp EIGENPY_NPY_CONST_UFUNC_ARG *dimensions,
71+
npy_intp EIGENPY_NPY_CONST_UFUNC_ARG *steps, void *NPY_UNUSED(func))
72+
{
73+
/* outer dimensions counter */
74+
npy_intp N_;
75+
76+
/* length of flattened outer dimensions */
77+
npy_intp dN = dimensions[0];
78+
79+
/* striding over flattened outer dimensions for input and output arrays */
80+
npy_intp s0 = steps[0];
81+
npy_intp s1 = steps[1];
82+
npy_intp s2 = steps[2];
83+
84+
/*
85+
* loop through outer dimensions, performing matrix multiply on
86+
* core dimensions for each loop
87+
*/
88+
for (N_ = 0; N_ < dN; N_++, args[0] += s0, args[1] += s1, args[2] += s2)
89+
{
90+
matrix_multiply<T>(args, dimensions+1, steps+3);
91+
}
92+
}
93+
2194
#define EIGENPY_REGISTER_BINARY_OPERATOR(name,op) \
2295
template<typename T1, typename T2, typename R> \
2396
void binary_op_##name(char** args, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * dimensions, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * steps, void * /*data*/) \
@@ -127,7 +200,7 @@ namespace eigenpy
127200
template<typename Scalar>
128201
void registerCommonUfunc()
129202
{
130-
const int code = Register::getTypeCode<Scalar>();
203+
const int type_code = Register::getTypeCode<Scalar>();
131204

132205
PyObject* numpy_str;
133206
#if PY_MAJOR_VERSION >= 3
@@ -140,23 +213,48 @@ namespace eigenpy
140213
Py_DECREF(numpy_str);
141214

142215
import_ufunc();
216+
217+
// Matrix multiply
218+
{
219+
int types[3] = {type_code,type_code,type_code};
220+
221+
std::stringstream ss;
222+
ss << "return result of multiplying two matrices of ";
223+
ss << bp::type_info(typeid(Scalar)).name();
224+
PyUFuncObject* ufunc = (PyUFuncObject*)PyObject_GetAttrString(numpy, "matmul");
225+
if(!ufunc)
226+
{
227+
std::stringstream ss;
228+
ss << "Impossible to define matrix_multiply for given type " << bp::type_info(typeid(Scalar)).name() << std::endl;
229+
eigenpy::Exception(ss.str());
230+
}
231+
if(PyUFunc_RegisterLoopForType((PyUFuncObject*)ufunc, type_code,
232+
&internal::gufunc_matrix_multiply<Scalar>, types, 0) < 0)
233+
{
234+
std::stringstream ss;
235+
ss << "Impossible to register matrix_multiply for given type " << bp::type_info(typeid(Scalar)).name() << std::endl;
236+
eigenpy::Exception(ss.str());
237+
}
238+
239+
Py_DECREF(ufunc);
240+
}
143241

144242
// Binary operators
145-
EIGENPY_REGISTER_BINARY_UFUNC(add,code,Scalar,Scalar,Scalar);
146-
EIGENPY_REGISTER_BINARY_UFUNC(subtract,code,Scalar,Scalar,Scalar);
147-
EIGENPY_REGISTER_BINARY_UFUNC(multiply,code,Scalar,Scalar,Scalar);
148-
EIGENPY_REGISTER_BINARY_UFUNC(divide,code,Scalar,Scalar,Scalar);
243+
EIGENPY_REGISTER_BINARY_UFUNC(add,type_code,Scalar,Scalar,Scalar);
244+
EIGENPY_REGISTER_BINARY_UFUNC(subtract,type_code,Scalar,Scalar,Scalar);
245+
EIGENPY_REGISTER_BINARY_UFUNC(multiply,type_code,Scalar,Scalar,Scalar);
246+
EIGENPY_REGISTER_BINARY_UFUNC(divide,type_code,Scalar,Scalar,Scalar);
149247

150248
// Comparison operators
151-
EIGENPY_REGISTER_BINARY_UFUNC(equal,code,Scalar,Scalar,bool);
152-
EIGENPY_REGISTER_BINARY_UFUNC(not_equal,code,Scalar,Scalar,bool);
153-
EIGENPY_REGISTER_BINARY_UFUNC(greater,code,Scalar,Scalar,bool);
154-
EIGENPY_REGISTER_BINARY_UFUNC(less,code,Scalar,Scalar,bool);
155-
EIGENPY_REGISTER_BINARY_UFUNC(greater_equal,code,Scalar,Scalar,bool);
156-
EIGENPY_REGISTER_BINARY_UFUNC(less_equal,code,Scalar,Scalar,bool);
249+
EIGENPY_REGISTER_BINARY_UFUNC(equal,type_code,Scalar,Scalar,bool);
250+
EIGENPY_REGISTER_BINARY_UFUNC(not_equal,type_code,Scalar,Scalar,bool);
251+
EIGENPY_REGISTER_BINARY_UFUNC(greater,type_code,Scalar,Scalar,bool);
252+
EIGENPY_REGISTER_BINARY_UFUNC(less,type_code,Scalar,Scalar,bool);
253+
EIGENPY_REGISTER_BINARY_UFUNC(greater_equal,type_code,Scalar,Scalar,bool);
254+
EIGENPY_REGISTER_BINARY_UFUNC(less_equal,type_code,Scalar,Scalar,bool);
157255

158256
// Unary operators
159-
EIGENPY_REGISTER_UNARY_UFUNC(negative,code,Scalar,Scalar);
257+
EIGENPY_REGISTER_UNARY_UFUNC(negative,type_code,Scalar,Scalar);
160258

161259
Py_DECREF(numpy);
162260
}

unittest/python/test_user_type.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def test(dtype):
3232
assert not (mat > mat).all()
3333
assert not (mat < mat).all()
3434

35+
mat2 = mat.dot(mat.T)
36+
if np.__version__ >= '1.17.0':
37+
mat2 = np.matmul(mat,mat.T)
38+
3539
def test_cast(from_dtype,to_dtype):
3640
np.can_cast(from_dtype,to_dtype)
3741

0 commit comments

Comments
 (0)