11//
2- // Copyright (c) 2020 INRIA
2+ // Copyright (c) 2020-2021 INRIA
3+ // code adapted from https://github.com/numpy/numpy/blob/41977b24ae011a51f64faa75cb524c7350fdedd9/numpy/core/src/umath/_rational_tests.c.src
34//
45
56#ifndef __eigenpy_ufunc_hpp__
67#define __eigenpy_ufunc_hpp__
78
89#include " eigenpy/register.hpp"
10+ #include " eigenpy/user-type.hpp"
911
1012namespace eigenpy
1113{
@@ -18,6 +20,77 @@ namespace eigenpy
1820 #define EIGENPY_NPY_CONST_UFUNC_ARG
1921#endif
2022
23+ template <typename T>
24+ void matrix_multiply (char **args, npy_intp const *dimensions, npy_intp const *steps)
25+ {
26+ /* pointers to data for input and output arrays */
27+ char *ip1 = args[0 ];
28+ char *ip2 = args[1 ];
29+ char *op = args[2 ];
30+
31+ /* lengths of core dimensions */
32+ npy_intp dm = dimensions[0 ];
33+ npy_intp dn = dimensions[1 ];
34+ npy_intp dp = dimensions[2 ];
35+
36+ /* striding over core dimensions */
37+ npy_intp is1_m = steps[0 ];
38+ npy_intp is1_n = steps[1 ];
39+ npy_intp is2_n = steps[2 ];
40+ npy_intp is2_p = steps[3 ];
41+ npy_intp os_m = steps[4 ];
42+ npy_intp os_p = steps[5 ];
43+
44+ /* core dimensions counters */
45+ npy_intp m, p;
46+
47+ /* calculate dot product for each row/column vector pair */
48+ for (m = 0 ; m < dm; m++)
49+ {
50+ for (p = 0 ; p < dp; p++)
51+ {
52+ SpecialMethods<T>::dotfunc (ip1, is1_n, ip2, is2_n, op, dn, NULL );
53+
54+ /* advance to next column of 2nd input array and output array */
55+ ip2 += is2_p;
56+ op += os_p;
57+ }
58+
59+ /* reset to first column of 2nd input array and output array */
60+ ip2 -= is2_p * p;
61+ op -= os_p * p;
62+
63+ /* advance to next row of 1st input array and output array */
64+ ip1 += is1_m;
65+ op += os_m;
66+ }
67+ }
68+
69+ template <typename T>
70+ void gufunc_matrix_multiply (char **args, npy_intp EIGENPY_NPY_CONST_UFUNC_ARG *dimensions,
71+ npy_intp EIGENPY_NPY_CONST_UFUNC_ARG *steps, void *NPY_UNUSED (func))
72+ {
73+ /* outer dimensions counter */
74+ npy_intp N_;
75+
76+ /* length of flattened outer dimensions */
77+ npy_intp dN = dimensions[0 ];
78+
79+ /* striding over flattened outer dimensions for input and output arrays */
80+ npy_intp s0 = steps[0 ];
81+ npy_intp s1 = steps[1 ];
82+ npy_intp s2 = steps[2 ];
83+
84+ /*
85+ * loop through outer dimensions, performing matrix multiply on
86+ * core dimensions for each loop
87+ */
88+ for (N_ = 0 ; N_ < dN; N_++, args[0 ] += s0, args[1 ] += s1, args[2 ] += s2)
89+ {
90+ matrix_multiply<T>(args, dimensions+1 , steps+3 );
91+ }
92+ }
93+
2194#define EIGENPY_REGISTER_BINARY_OPERATOR (name,op ) \
2295 template <typename T1, typename T2, typename R> \
2396 void binary_op_##name(char ** args, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * dimensions, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * steps, void * /* data*/ ) \
@@ -127,7 +200,7 @@ namespace eigenpy
127200 template <typename Scalar>
128201 void registerCommonUfunc ()
129202 {
130- const int code = Register::getTypeCode<Scalar>();
203+ const int type_code = Register::getTypeCode<Scalar>();
131204
132205 PyObject* numpy_str;
133206#if PY_MAJOR_VERSION >= 3
@@ -140,23 +213,48 @@ namespace eigenpy
140213 Py_DECREF (numpy_str);
141214
142215 import_ufunc ();
216+
217+ // Matrix multiply
218+ {
219+ int types[3 ] = {type_code,type_code,type_code};
220+
221+ std::stringstream ss;
222+ ss << " return result of multiplying two matrices of " ;
223+ ss << bp::type_info (typeid (Scalar)).name ();
224+ PyUFuncObject* ufunc = (PyUFuncObject*)PyObject_GetAttrString (numpy, " matmul" );
225+ if (!ufunc)
226+ {
227+ std::stringstream ss;
228+ ss << " Impossible to define matrix_multiply for given type " << bp::type_info (typeid (Scalar)).name () << std::endl;
229+ eigenpy::Exception (ss.str ());
230+ }
231+ if (PyUFunc_RegisterLoopForType ((PyUFuncObject*)ufunc, type_code,
232+ &internal::gufunc_matrix_multiply<Scalar>, types, 0 ) < 0 )
233+ {
234+ std::stringstream ss;
235+ ss << " Impossible to register matrix_multiply for given type " << bp::type_info (typeid (Scalar)).name () << std::endl;
236+ eigenpy::Exception (ss.str ());
237+ }
238+
239+ Py_DECREF (ufunc);
240+ }
143241
144242 // Binary operators
145- EIGENPY_REGISTER_BINARY_UFUNC (add,code ,Scalar,Scalar,Scalar);
146- EIGENPY_REGISTER_BINARY_UFUNC (subtract,code ,Scalar,Scalar,Scalar);
147- EIGENPY_REGISTER_BINARY_UFUNC (multiply,code ,Scalar,Scalar,Scalar);
148- EIGENPY_REGISTER_BINARY_UFUNC (divide,code ,Scalar,Scalar,Scalar);
243+ EIGENPY_REGISTER_BINARY_UFUNC (add,type_code ,Scalar,Scalar,Scalar);
244+ EIGENPY_REGISTER_BINARY_UFUNC (subtract,type_code ,Scalar,Scalar,Scalar);
245+ EIGENPY_REGISTER_BINARY_UFUNC (multiply,type_code ,Scalar,Scalar,Scalar);
246+ EIGENPY_REGISTER_BINARY_UFUNC (divide,type_code ,Scalar,Scalar,Scalar);
149247
150248 // Comparison operators
151- EIGENPY_REGISTER_BINARY_UFUNC (equal,code ,Scalar,Scalar,bool );
152- EIGENPY_REGISTER_BINARY_UFUNC (not_equal,code ,Scalar,Scalar,bool );
153- EIGENPY_REGISTER_BINARY_UFUNC (greater,code ,Scalar,Scalar,bool );
154- EIGENPY_REGISTER_BINARY_UFUNC (less,code ,Scalar,Scalar,bool );
155- EIGENPY_REGISTER_BINARY_UFUNC (greater_equal,code ,Scalar,Scalar,bool );
156- EIGENPY_REGISTER_BINARY_UFUNC (less_equal,code ,Scalar,Scalar,bool );
249+ EIGENPY_REGISTER_BINARY_UFUNC (equal,type_code ,Scalar,Scalar,bool );
250+ EIGENPY_REGISTER_BINARY_UFUNC (not_equal,type_code ,Scalar,Scalar,bool );
251+ EIGENPY_REGISTER_BINARY_UFUNC (greater,type_code ,Scalar,Scalar,bool );
252+ EIGENPY_REGISTER_BINARY_UFUNC (less,type_code ,Scalar,Scalar,bool );
253+ EIGENPY_REGISTER_BINARY_UFUNC (greater_equal,type_code ,Scalar,Scalar,bool );
254+ EIGENPY_REGISTER_BINARY_UFUNC (less_equal,type_code ,Scalar,Scalar,bool );
157255
158256 // Unary operators
159- EIGENPY_REGISTER_UNARY_UFUNC (negative,code ,Scalar,Scalar);
257+ EIGENPY_REGISTER_UNARY_UFUNC (negative,type_code ,Scalar,Scalar);
160258
161259 Py_DECREF (numpy);
162260 }
0 commit comments