Skip to content

Commit 0ce3896

Browse files
committed
Use col_del, row_del, col_join, row_join, column_exchange_dense, row_exchange_dense, dot and cross functions from SymEngine
1 parent 3d90759 commit 0ce3896

File tree

2 files changed

+53
-22
lines changed

2 files changed

+53
-22
lines changed

symengine/lib/symengine.pxd

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,14 @@ cdef extern from "<symengine/matrix.h>" namespace "SymEngine":
788788
const DenseMatrix &x, DenseMatrix &result) nogil except +
789789
void diff "SymEngine::sdiff"(const DenseMatrix &A,
790790
RCP[const Basic] &x, DenseMatrix &result) nogil except +
791+
void row_join(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
792+
void col_join(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
793+
void row_del(DenseMatrix &A, unsigned k) nogil
794+
void col_del(DenseMatrix &A, unsigned k) nogil
795+
void row_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
796+
void column_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
797+
void dot(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
798+
void cross(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
791799
void eye (DenseMatrix &A, int k) nogil
792800
void diag(DenseMatrix &A, vec_basic &v, int k) nogil
793801
void ones(DenseMatrix &A) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2506,31 +2506,38 @@ cdef class DenseMatrixBase(MatrixBase):
25062506
if self.rows != o.rows:
25072507
raise ShapeError("`self` and `rhs` must have the same number of rows.")
25082508
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols + o.cols)
2509-
cdef Basic e_
2510-
for i in range(self.rows):
2511-
for j in range(self.cols):
2512-
e_ = self._get(i, j)
2513-
deref(result.thisptr).set(i, j, e_.thisptr)
2514-
for i in range(o.rows):
2515-
for j in range(o.cols):
2516-
e_ = sympify(o._get(i, j))
2517-
deref(result.thisptr).set(i, j + self.cols, e_.thisptr)
2509+
symengine.row_join(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2510+
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2511+
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
25182512
return result
25192513

25202514
def col_join(self, bott):
25212515
cdef DenseMatrixBase o = sympify(bott)
25222516
if self.cols != o.cols:
25232517
raise ShapeError("`self` and `rhs` must have the same number of columns.")
25242518
cdef DenseMatrixBase result = self.__class__(self.rows + o.rows, self.cols)
2525-
cdef Basic e_
2526-
for i in range(self.rows):
2527-
for j in range(self.cols):
2528-
e_ = self._get(i, j)
2529-
deref(result.thisptr).set(i, j, e_.thisptr)
2530-
for i in range(o.rows):
2531-
for j in range(o.cols):
2532-
e_ = sympify(o._get(i, j))
2533-
deref(result.thisptr).set(i + self.rows, j, e_.thisptr)
2519+
symengine.col_join(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2520+
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2521+
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2522+
return result
2523+
2524+
def dot(self, b):
2525+
cdef DenseMatrixBase o = sympify(b)
2526+
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols)
2527+
symengine.dot(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2528+
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2529+
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2530+
return result
2531+
2532+
def cross(self, b):
2533+
cdef DenseMatrixBase o = sympify(b)
2534+
if self.cols * self.rows != 3 or o.cols * o.rows != 3:
2535+
raise ShapeError("Dimensions incorrect for cross product: %s x %s" %
2536+
((self.rows, self.cols), (b.rows, b.cols)))
2537+
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols)
2538+
symengine.cross(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2539+
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2540+
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
25342541
return result
25352542

25362543
@property
@@ -2864,17 +2871,33 @@ class DenseMatrixBaseIter(object):
28642871
cdef class MutableDenseMatrix(DenseMatrixBase):
28652872

28662873
def col_swap(self, i, j):
2867-
for k in range(0, self.rows):
2868-
self[k, i], self[k, j] = self[k, j], self[k, i]
2874+
symengine.column_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2875+
i, j)
28692876

28702877
def fill(self, value):
28712878
for i in range(self.rows):
28722879
for j in range(self.cols):
28732880
self[i, j] = value
28742881

28752882
def row_swap(self, i, j):
2876-
for k in range(0, self.cols):
2877-
self[i, k], self[j, k] = self[j, k], self[i, k]
2883+
symengine.row_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2884+
i, j)
2885+
2886+
def row_del(self, i):
2887+
if i < -self.rows or i >= self.rows:
2888+
raise IndexError("Index out of range: 'i = %s', valid -%s <= i"
2889+
" < %s" % (i, self.rows, self.rows))
2890+
if i < 0:
2891+
i += self.rows
2892+
symengine.row_del(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i)
2893+
2894+
def col_del(self, i):
2895+
if i < -self.cols or i >= self.cols:
2896+
raise IndexError("Index out of range: 'i=%s', valid -%s <= i < %s"
2897+
% (i, self.cols, self.cols))
2898+
if i < 0:
2899+
i += self.cols
2900+
symengine.col_del(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i)
28782901

28792902
def _applyfunc(self, f):
28802903
cdef int nr = self.nrows()

0 commit comments

Comments
 (0)