Skip to content

Commit ac2b834

Browse files
committed
Improvements
1 parent 0ce3896 commit ac2b834

File tree

4 files changed

+178
-45
lines changed

4 files changed

+178
-45
lines changed

symengine/lib/symengine.pxd

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,12 @@ cdef extern from "<symengine/matrix.h>" namespace "SymEngine":
769769
DenseMatrix(unsigned i, unsigned j) nogil
770770
DenseMatrix(unsigned i, unsigned j, const vec_basic &v) nogil
771771
void resize(unsigned i, unsigned j) nogil
772+
void row_join(const DenseMatrix &B) nogil
773+
void col_join(const DenseMatrix &B) nogil
774+
void row_insert(const DenseMatrix &B, unsigned pos) nogil
775+
void col_insert(const DenseMatrix &B, unsigned pos) nogil
776+
void row_del(unsigned k) nogil
777+
void col_del(unsigned k) nogil
772778

773779
bool is_a_DenseMatrix "SymEngine::is_a<SymEngine::DenseMatrix>"(const MatrixBase &b) nogil
774780
DenseMatrix* static_cast_DenseMatrix "static_cast<SymEngine::DenseMatrix*>"(const MatrixBase *a)
@@ -788,18 +794,16 @@ cdef extern from "<symengine/matrix.h>" namespace "SymEngine":
788794
const DenseMatrix &x, DenseMatrix &result) nogil except +
789795
void diff "SymEngine::sdiff"(const DenseMatrix &A,
790796
RCP[const Basic] &x, DenseMatrix &result) nogil except +
791-
void row_join(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
792-
void col_join(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
793-
void row_del(DenseMatrix &A, unsigned k) nogil
794-
void col_del(DenseMatrix &A, unsigned k) nogil
795-
void row_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
796-
void column_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
797-
void dot(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
798-
void cross(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
799797
void eye (DenseMatrix &A, int k) nogil
800798
void diag(DenseMatrix &A, vec_basic &v, int k) nogil
801799
void ones(DenseMatrix &A) nogil
802800
void zeros(DenseMatrix &A) nogil
801+
void row_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
802+
void row_mul_scalar_dense(DenseMatrix &A, unsigned i, RCP[const Basic] &c) nogil
803+
void row_add_row_dense(DenseMatrix &A, unsigned i, unsigned j, RCP[const Basic] &c) nogil
804+
void column_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
805+
void dot(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
806+
void cross(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
803807

804808
cdef extern from "<symengine/ntheory.h>" namespace "SymEngine":
805809
int probab_prime_p(const Integer &a, int reps)

symengine/lib/symengine_wrapper.pyx

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,39 +2505,62 @@ cdef class DenseMatrixBase(MatrixBase):
25052505
cdef DenseMatrixBase o = sympify(rhs)
25062506
if self.rows != o.rows:
25072507
raise ShapeError("`self` and `rhs` must have the same number of rows.")
2508-
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols + o.cols)
2509-
symengine.row_join(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2510-
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2511-
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2512-
return result
2508+
cdef DenseMatrixBase d = self.__class__(self)
2509+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).row_join(deref(symengine.static_cast_DenseMatrix(o.thisptr)))
2510+
return d
25132511

25142512
def col_join(self, bott):
25152513
cdef DenseMatrixBase o = sympify(bott)
25162514
if self.cols != o.cols:
25172515
raise ShapeError("`self` and `rhs` must have the same number of columns.")
2518-
cdef DenseMatrixBase result = self.__class__(self.rows + o.rows, self.cols)
2519-
symengine.col_join(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2520-
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2521-
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2522-
return result
2516+
cdef DenseMatrixBase d = self.__class__(self)
2517+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).col_join(deref(symengine.static_cast_DenseMatrix(o.thisptr)))
2518+
return d
2519+
2520+
def row_insert(self, pos, bott):
2521+
cdef DenseMatrixBase o = sympify(bott)
2522+
if pos < 0:
2523+
pos = self.rows + pos
2524+
if pos < 0:
2525+
pos = 0
2526+
elif pos > self.rows:
2527+
pos = self.rows
2528+
if self.cols != o.cols:
2529+
raise ShapeError("`self` and `other` must have the same number of columns.")
2530+
cdef DenseMatrixBase d = self.__class__(self)
2531+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).row_insert(deref(symengine.static_cast_DenseMatrix(o.thisptr)), pos)
2532+
return d
2533+
2534+
def col_insert(self, pos, bott):
2535+
cdef DenseMatrixBase o = sympify(bott)
2536+
if pos < 0:
2537+
pos = self.cols + pos
2538+
if pos < 0:
2539+
pos = 0
2540+
elif pos > self.cols:
2541+
pos = self.cols
2542+
if self.rows != o.rows:
2543+
raise ShapeError("`self` and `other` must have the same number of rows.")
2544+
cdef DenseMatrixBase d = self.__class__(self)
2545+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).col_insert(deref(symengine.static_cast_DenseMatrix(o.thisptr)), pos)
2546+
return d
25232547

25242548
def dot(self, b):
25252549
cdef DenseMatrixBase o = sympify(b)
25262550
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols)
2527-
symengine.dot(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2528-
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2529-
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2530-
return result
2551+
symengine.dot(deref(symengine.static_cast_DenseMatrix(self.thisptr)), deref(symengine.static_cast_DenseMatrix(o.thisptr)), deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2552+
if len(result) == 1:
2553+
return result[0, 0]
2554+
else:
2555+
return result
25312556

25322557
def cross(self, b):
25332558
cdef DenseMatrixBase o = sympify(b)
25342559
if self.cols * self.rows != 3 or o.cols * o.rows != 3:
25352560
raise ShapeError("Dimensions incorrect for cross product: %s x %s" %
25362561
((self.rows, self.cols), (b.rows, b.cols)))
25372562
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols)
2538-
symengine.cross(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2539-
deref(symengine.static_cast_DenseMatrix(o.thisptr)),
2540-
deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2563+
symengine.cross(deref(symengine.static_cast_DenseMatrix(self.thisptr)), deref(symengine.static_cast_DenseMatrix(o.thisptr)), deref(symengine.static_cast_DenseMatrix(result.thisptr)))
25412564
return result
25422565

25432566
@property
@@ -2548,6 +2571,10 @@ cdef class DenseMatrixBase(MatrixBase):
25482571
def cols(self):
25492572
return self.ncols()
25502573

2574+
@property
2575+
def is_square(self):
2576+
return self.rows == self.cols
2577+
25512578
def nrows(self):
25522579
return deref(self.thisptr).nrows()
25532580

@@ -2596,6 +2623,12 @@ cdef class DenseMatrixBase(MatrixBase):
25962623
# No error checking is done
25972624
return c2py(deref(self.thisptr).get(i, j))
25982625

2626+
def col(self, j):
2627+
return self[:, j]
2628+
2629+
def row(self, i):
2630+
return self[i, :]
2631+
25992632
def set(self, i, j, e):
26002633
i, j = self._get_index(i, j)
26012634
return self._set(i, j, e)
@@ -2672,6 +2705,13 @@ cdef class DenseMatrixBase(MatrixBase):
26722705
deref(out.thisptr).set(i, j, e_.thisptr)
26732706
return out
26742707

2708+
def _applyfunc(self, f):
2709+
cdef int nr = self.nrows()
2710+
cdef int nc = self.ncols()
2711+
for i in range(nr):
2712+
for j in range(nc):
2713+
self._set(i, j, f(self._get(i, j)))
2714+
26752715
def msubs(self, *args):
26762716
cdef _DictBasic D = get_dict(*args)
26772717
return self.applyfunc(lambda x: x.msubs(D))
@@ -2833,6 +2873,9 @@ cdef class DenseMatrixBase(MatrixBase):
28332873
def tolist(self):
28342874
return self[:]
28352875

2876+
def _mat(self):
2877+
return self
2878+
28362879
def atoms(self, *types):
28372880
if types:
28382881
s = set()
@@ -2871,40 +2914,43 @@ class DenseMatrixBaseIter(object):
28712914
cdef class MutableDenseMatrix(DenseMatrixBase):
28722915

28732916
def col_swap(self, i, j):
2874-
symengine.column_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2875-
i, j)
2917+
symengine.column_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, j)
28762918

28772919
def fill(self, value):
28782920
for i in range(self.rows):
28792921
for j in range(self.cols):
28802922
self[i, j] = value
28812923

28822924
def row_swap(self, i, j):
2883-
symengine.row_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
2884-
i, j)
2925+
symengine.row_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, j)
2926+
2927+
def rowmul(self, i, c, *args):
2928+
cdef Basic _c = sympify(c)
2929+
symengine.row_mul_scalar_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, _c.thisptr)
2930+
return self
2931+
2932+
def rowadd(self, i, j, c, *args):
2933+
cdef Basic _c = sympify(c)
2934+
symengine.row_add_row_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, j, _c.thisptr)
2935+
return self
28852936

28862937
def row_del(self, i):
28872938
if i < -self.rows or i >= self.rows:
28882939
raise IndexError("Index out of range: 'i = %s', valid -%s <= i"
28892940
" < %s" % (i, self.rows, self.rows))
28902941
if i < 0:
28912942
i += self.rows
2892-
symengine.row_del(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i)
2943+
deref(symengine.static_cast_DenseMatrix(self.thisptr)).row_del(i)
2944+
return self
28932945

28942946
def col_del(self, i):
28952947
if i < -self.cols or i >= self.cols:
28962948
raise IndexError("Index out of range: 'i=%s', valid -%s <= i < %s"
28972949
% (i, self.cols, self.cols))
28982950
if i < 0:
28992951
i += self.cols
2900-
symengine.col_del(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i)
2901-
2902-
def _applyfunc(self, f):
2903-
cdef int nr = self.nrows()
2904-
cdef int nc = self.ncols()
2905-
for i in range(nr):
2906-
for j in range(nc):
2907-
self._set(i, j, f(self._get(i, j)))
2952+
deref(symengine.static_cast_DenseMatrix(self.thisptr)).col_del(i)
2953+
return self
29082954

29092955
Matrix = DenseMatrix = MutableDenseMatrix
29102956

@@ -2913,12 +2959,6 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):
29132959
def __setitem__(self, key, value):
29142960
raise TypeError("Cannot set values of {}".format(self.__class__))
29152961

2916-
def set(self, i, j, e):
2917-
raise TypeError("Cannot set values of {}".format(self.__class__))
2918-
2919-
def _set(self, i, j, e):
2920-
raise TypeError("Cannot set values of {}".format(self.__class__))
2921-
29222962
ImmutableMatrix = ImmutableDenseMatrix
29232963

29242964
cdef matrix_to_vec(DenseMatrixBase d, symengine.vec_basic& v):

symengine/tests/test_matrices.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,95 @@ def test_row_swap():
369369
assert A == B
370370

371371

372+
def test_row_col_del():
373+
e = DenseMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
374+
raises(IndexError, lambda: e.row_del(5))
375+
raises(IndexError, lambda: e.row_del(-5))
376+
raises(IndexError, lambda: e.col_del(5))
377+
raises(IndexError, lambda: e.col_del(-5))
378+
379+
assert e.row_del(-1) == DenseMatrix([[1, 2, 3], [4, 5, 6]])
380+
assert e.col_del(-1) == DenseMatrix([[1, 2], [4, 5]])
381+
382+
e = DenseMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
383+
assert e.row_del(1) == DenseMatrix([[1, 2, 3], [7, 8, 9]])
384+
assert e.col_del(1) == DenseMatrix([[1, 3], [7, 9]])
385+
386+
387+
def test_row_join():
388+
assert eye(3).row_join(DenseMatrix([7, 7, 7])) == \
389+
DenseMatrix([[1, 0, 0, 7],
390+
[0, 1, 0, 7],
391+
[0, 0, 1, 7]])
392+
393+
394+
def test_col_join():
395+
assert eye(3).col_join(DenseMatrix([[7, 7, 7]])) == \
396+
DenseMatrix([[1, 0, 0],
397+
[0, 1, 0],
398+
[0, 0, 1],
399+
[7, 7, 7]])
400+
401+
402+
def test_row_insert():
403+
M = zeros(3)
404+
V = ones(1, 3)
405+
assert M.row_insert(1, V) == DenseMatrix([[0, 0, 0],
406+
[1, 1, 1],
407+
[0, 0, 0],
408+
[0, 0, 0]])
409+
410+
411+
def test_col_insert():
412+
M = zeros(3)
413+
V = ones(3, 1)
414+
assert M.col_insert(1, V) == DenseMatrix([[0, 1, 0, 0],
415+
[0, 1, 0, 0],
416+
[0, 1, 0, 0]])
417+
418+
419+
def test_rowmul():
420+
M = ones(3)
421+
assert M.rowmul(2, 2) == DenseMatrix([[1, 1, 1],
422+
[1, 1, 1],
423+
[2, 2, 2]])
424+
425+
426+
def test_rowadd():
427+
M = ones(3)
428+
assert M.rowadd(2, 1, 1) == DenseMatrix([[1, 1, 1],
429+
[1, 1, 1],
430+
[2, 2, 2]])
431+
432+
433+
def test_row_col():
434+
m = DenseMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
435+
assert m.row(0) == DenseMatrix(1, 3, [1, 2, 3])
436+
assert m.col(0) == DenseMatrix(3, 1, [1, 4, 7])
437+
438+
439+
def test_is_square():
440+
m = DenseMatrix([[1],[1]])
441+
m2 = DenseMatrix([[2, 2], [2, 2]])
442+
assert not m.is_square
443+
assert m2.is_square
444+
445+
446+
def test_dot():
447+
A = DenseMatrix(2, 3, [1, 2, 3, 4, 5, 6])
448+
B = DenseMatrix(2, 1, [7, 8])
449+
assert A.dot(B) == DenseMatrix(1, 3, [39, 54, 69])
450+
assert ones(1, 3).dot(ones(3, 1)) == 3
451+
452+
453+
def test_cross():
454+
M = DenseMatrix(1, 3, [1, 2, 3])
455+
V = DenseMatrix(1, 3, [3, 4, 5])
456+
assert M.cross(V) == DenseMatrix(1, 3, [-2, 4, -2])
457+
raises(ShapeError, lambda:
458+
DenseMatrix(1, 2, [1, 1]).cross(DenseMatrix(1, 2, [1, 1])))
459+
460+
372461
def test_immutablematrix():
373462
A = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
374463
assert A.shape == (3, 3)

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d13fec95c651bbce195988a6d9a146e9b726b2c2
1+
978dfd1656d5dbc722f5dc448bbe96e97b2d7be9

0 commit comments

Comments
 (0)