Skip to content

Commit cf2d943

Browse files
authored
Merge pull request #143 from ShikharJ/DenseMatrix
Additions to DenseMatrixBase and MutableDenseMatrix
2 parents 54c3053 + ac2b834 commit cf2d943

File tree

4 files changed

+202
-38
lines changed

4 files changed

+202
-38
lines changed

symengine/lib/symengine.pxd

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,12 @@ cdef extern from "<symengine/matrix.h>" namespace "SymEngine":
769769
DenseMatrix(unsigned i, unsigned j) nogil
770770
DenseMatrix(unsigned i, unsigned j, const vec_basic &v) nogil
771771
void resize(unsigned i, unsigned j) nogil
772+
void row_join(const DenseMatrix &B) nogil
773+
void col_join(const DenseMatrix &B) nogil
774+
void row_insert(const DenseMatrix &B, unsigned pos) nogil
775+
void col_insert(const DenseMatrix &B, unsigned pos) nogil
776+
void row_del(unsigned k) nogil
777+
void col_del(unsigned k) nogil
772778

773779
bool is_a_DenseMatrix "SymEngine::is_a<SymEngine::DenseMatrix>"(const MatrixBase &b) nogil
774780
DenseMatrix* static_cast_DenseMatrix "static_cast<SymEngine::DenseMatrix*>"(const MatrixBase *a)
@@ -792,6 +798,12 @@ cdef extern from "<symengine/matrix.h>" namespace "SymEngine":
792798
void diag(DenseMatrix &A, vec_basic &v, int k) nogil
793799
void ones(DenseMatrix &A) nogil
794800
void zeros(DenseMatrix &A) nogil
801+
void row_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
802+
void row_mul_scalar_dense(DenseMatrix &A, unsigned i, RCP[const Basic] &c) nogil
803+
void row_add_row_dense(DenseMatrix &A, unsigned i, unsigned j, RCP[const Basic] &c) nogil
804+
void column_exchange_dense(DenseMatrix &A, unsigned i, unsigned j) nogil
805+
void dot(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
806+
void cross(const DenseMatrix &A, const DenseMatrix &B, DenseMatrix &C) nogil
795807

796808
cdef extern from "<symengine/ntheory.h>" namespace "SymEngine":
797809
int probab_prime_p(const Integer &a, int reps)

symengine/lib/symengine_wrapper.pyx

Lines changed: 100 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,32 +2505,62 @@ cdef class DenseMatrixBase(MatrixBase):
25052505
cdef DenseMatrixBase o = sympify(rhs)
25062506
if self.rows != o.rows:
25072507
raise ShapeError("`self` and `rhs` must have the same number of rows.")
2508-
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols + o.cols)
2509-
cdef Basic e_
2510-
for i in range(self.rows):
2511-
for j in range(self.cols):
2512-
e_ = self._get(i, j)
2513-
deref(result.thisptr).set(i, j, e_.thisptr)
2514-
for i in range(o.rows):
2515-
for j in range(o.cols):
2516-
e_ = sympify(o._get(i, j))
2517-
deref(result.thisptr).set(i, j + self.cols, e_.thisptr)
2518-
return result
2508+
cdef DenseMatrixBase d = self.__class__(self)
2509+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).row_join(deref(symengine.static_cast_DenseMatrix(o.thisptr)))
2510+
return d
25192511

25202512
def col_join(self, bott):
25212513
cdef DenseMatrixBase o = sympify(bott)
25222514
if self.cols != o.cols:
25232515
raise ShapeError("`self` and `rhs` must have the same number of columns.")
2524-
cdef DenseMatrixBase result = self.__class__(self.rows + o.rows, self.cols)
2525-
cdef Basic e_
2526-
for i in range(self.rows):
2527-
for j in range(self.cols):
2528-
e_ = self._get(i, j)
2529-
deref(result.thisptr).set(i, j, e_.thisptr)
2530-
for i in range(o.rows):
2531-
for j in range(o.cols):
2532-
e_ = sympify(o._get(i, j))
2533-
deref(result.thisptr).set(i + self.rows, j, e_.thisptr)
2516+
cdef DenseMatrixBase d = self.__class__(self)
2517+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).col_join(deref(symengine.static_cast_DenseMatrix(o.thisptr)))
2518+
return d
2519+
2520+
def row_insert(self, pos, bott):
2521+
cdef DenseMatrixBase o = sympify(bott)
2522+
if pos < 0:
2523+
pos = self.rows + pos
2524+
if pos < 0:
2525+
pos = 0
2526+
elif pos > self.rows:
2527+
pos = self.rows
2528+
if self.cols != o.cols:
2529+
raise ShapeError("`self` and `other` must have the same number of columns.")
2530+
cdef DenseMatrixBase d = self.__class__(self)
2531+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).row_insert(deref(symengine.static_cast_DenseMatrix(o.thisptr)), pos)
2532+
return d
2533+
2534+
def col_insert(self, pos, bott):
2535+
cdef DenseMatrixBase o = sympify(bott)
2536+
if pos < 0:
2537+
pos = self.cols + pos
2538+
if pos < 0:
2539+
pos = 0
2540+
elif pos > self.cols:
2541+
pos = self.cols
2542+
if self.rows != o.rows:
2543+
raise ShapeError("`self` and `other` must have the same number of rows.")
2544+
cdef DenseMatrixBase d = self.__class__(self)
2545+
deref(symengine.static_cast_DenseMatrix(d.thisptr)).col_insert(deref(symengine.static_cast_DenseMatrix(o.thisptr)), pos)
2546+
return d
2547+
2548+
def dot(self, b):
2549+
cdef DenseMatrixBase o = sympify(b)
2550+
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols)
2551+
symengine.dot(deref(symengine.static_cast_DenseMatrix(self.thisptr)), deref(symengine.static_cast_DenseMatrix(o.thisptr)), deref(symengine.static_cast_DenseMatrix(result.thisptr)))
2552+
if len(result) == 1:
2553+
return result[0, 0]
2554+
else:
2555+
return result
2556+
2557+
def cross(self, b):
2558+
cdef DenseMatrixBase o = sympify(b)
2559+
if self.cols * self.rows != 3 or o.cols * o.rows != 3:
2560+
raise ShapeError("Dimensions incorrect for cross product: %s x %s" %
2561+
((self.rows, self.cols), (b.rows, b.cols)))
2562+
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols)
2563+
symengine.cross(deref(symengine.static_cast_DenseMatrix(self.thisptr)), deref(symengine.static_cast_DenseMatrix(o.thisptr)), deref(symengine.static_cast_DenseMatrix(result.thisptr)))
25342564
return result
25352565

25362566
@property
@@ -2541,6 +2571,10 @@ cdef class DenseMatrixBase(MatrixBase):
25412571
def cols(self):
25422572
return self.ncols()
25432573

2574+
@property
2575+
def is_square(self):
2576+
return self.rows == self.cols
2577+
25442578
def nrows(self):
25452579
return deref(self.thisptr).nrows()
25462580

@@ -2589,6 +2623,12 @@ cdef class DenseMatrixBase(MatrixBase):
25892623
# No error checking is done
25902624
return c2py(deref(self.thisptr).get(i, j))
25912625

2626+
def col(self, j):
2627+
return self[:, j]
2628+
2629+
def row(self, i):
2630+
return self[i, :]
2631+
25922632
def set(self, i, j, e):
25932633
i, j = self._get_index(i, j)
25942634
return self._set(i, j, e)
@@ -2665,6 +2705,13 @@ cdef class DenseMatrixBase(MatrixBase):
26652705
deref(out.thisptr).set(i, j, e_.thisptr)
26662706
return out
26672707

2708+
def _applyfunc(self, f):
2709+
cdef int nr = self.nrows()
2710+
cdef int nc = self.ncols()
2711+
for i in range(nr):
2712+
for j in range(nc):
2713+
self._set(i, j, f(self._get(i, j)))
2714+
26682715
def msubs(self, *args):
26692716
cdef _DictBasic D = get_dict(*args)
26702717
return self.applyfunc(lambda x: x.msubs(D))
@@ -2826,6 +2873,9 @@ cdef class DenseMatrixBase(MatrixBase):
28262873
def tolist(self):
28272874
return self[:]
28282875

2876+
def _mat(self):
2877+
return self
2878+
28292879
def atoms(self, *types):
28302880
if types:
28312881
s = set()
@@ -2864,24 +2914,43 @@ class DenseMatrixBaseIter(object):
28642914
cdef class MutableDenseMatrix(DenseMatrixBase):
28652915

28662916
def col_swap(self, i, j):
2867-
for k in range(0, self.rows):
2868-
self[k, i], self[k, j] = self[k, j], self[k, i]
2917+
symengine.column_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, j)
28692918

28702919
def fill(self, value):
28712920
for i in range(self.rows):
28722921
for j in range(self.cols):
28732922
self[i, j] = value
28742923

28752924
def row_swap(self, i, j):
2876-
for k in range(0, self.cols):
2877-
self[i, k], self[j, k] = self[j, k], self[i, k]
2925+
symengine.row_exchange_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, j)
28782926

2879-
def _applyfunc(self, f):
2880-
cdef int nr = self.nrows()
2881-
cdef int nc = self.ncols()
2882-
for i in range(nr):
2883-
for j in range(nc):
2884-
self._set(i, j, f(self._get(i, j)))
2927+
def rowmul(self, i, c, *args):
2928+
cdef Basic _c = sympify(c)
2929+
symengine.row_mul_scalar_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, _c.thisptr)
2930+
return self
2931+
2932+
def rowadd(self, i, j, c, *args):
2933+
cdef Basic _c = sympify(c)
2934+
symengine.row_add_row_dense(deref(symengine.static_cast_DenseMatrix(self.thisptr)), i, j, _c.thisptr)
2935+
return self
2936+
2937+
def row_del(self, i):
2938+
if i < -self.rows or i >= self.rows:
2939+
raise IndexError("Index out of range: 'i = %s', valid -%s <= i"
2940+
" < %s" % (i, self.rows, self.rows))
2941+
if i < 0:
2942+
i += self.rows
2943+
deref(symengine.static_cast_DenseMatrix(self.thisptr)).row_del(i)
2944+
return self
2945+
2946+
def col_del(self, i):
2947+
if i < -self.cols or i >= self.cols:
2948+
raise IndexError("Index out of range: 'i=%s', valid -%s <= i < %s"
2949+
% (i, self.cols, self.cols))
2950+
if i < 0:
2951+
i += self.cols
2952+
deref(symengine.static_cast_DenseMatrix(self.thisptr)).col_del(i)
2953+
return self
28852954

28862955
Matrix = DenseMatrix = MutableDenseMatrix
28872956

@@ -2890,12 +2959,6 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):
28902959
def __setitem__(self, key, value):
28912960
raise TypeError("Cannot set values of {}".format(self.__class__))
28922961

2893-
def set(self, i, j, e):
2894-
raise TypeError("Cannot set values of {}".format(self.__class__))
2895-
2896-
def _set(self, i, j, e):
2897-
raise TypeError("Cannot set values of {}".format(self.__class__))
2898-
28992962
ImmutableMatrix = ImmutableDenseMatrix
29002963

29012964
cdef matrix_to_vec(DenseMatrixBase d, symengine.vec_basic& v):

symengine/tests/test_matrices.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,95 @@ def test_row_swap():
369369
assert A == B
370370

371371

372+
def test_row_col_del():
373+
e = DenseMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
374+
raises(IndexError, lambda: e.row_del(5))
375+
raises(IndexError, lambda: e.row_del(-5))
376+
raises(IndexError, lambda: e.col_del(5))
377+
raises(IndexError, lambda: e.col_del(-5))
378+
379+
assert e.row_del(-1) == DenseMatrix([[1, 2, 3], [4, 5, 6]])
380+
assert e.col_del(-1) == DenseMatrix([[1, 2], [4, 5]])
381+
382+
e = DenseMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
383+
assert e.row_del(1) == DenseMatrix([[1, 2, 3], [7, 8, 9]])
384+
assert e.col_del(1) == DenseMatrix([[1, 3], [7, 9]])
385+
386+
387+
def test_row_join():
388+
assert eye(3).row_join(DenseMatrix([7, 7, 7])) == \
389+
DenseMatrix([[1, 0, 0, 7],
390+
[0, 1, 0, 7],
391+
[0, 0, 1, 7]])
392+
393+
394+
def test_col_join():
395+
assert eye(3).col_join(DenseMatrix([[7, 7, 7]])) == \
396+
DenseMatrix([[1, 0, 0],
397+
[0, 1, 0],
398+
[0, 0, 1],
399+
[7, 7, 7]])
400+
401+
402+
def test_row_insert():
403+
M = zeros(3)
404+
V = ones(1, 3)
405+
assert M.row_insert(1, V) == DenseMatrix([[0, 0, 0],
406+
[1, 1, 1],
407+
[0, 0, 0],
408+
[0, 0, 0]])
409+
410+
411+
def test_col_insert():
412+
M = zeros(3)
413+
V = ones(3, 1)
414+
assert M.col_insert(1, V) == DenseMatrix([[0, 1, 0, 0],
415+
[0, 1, 0, 0],
416+
[0, 1, 0, 0]])
417+
418+
419+
def test_rowmul():
420+
M = ones(3)
421+
assert M.rowmul(2, 2) == DenseMatrix([[1, 1, 1],
422+
[1, 1, 1],
423+
[2, 2, 2]])
424+
425+
426+
def test_rowadd():
427+
M = ones(3)
428+
assert M.rowadd(2, 1, 1) == DenseMatrix([[1, 1, 1],
429+
[1, 1, 1],
430+
[2, 2, 2]])
431+
432+
433+
def test_row_col():
434+
m = DenseMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])
435+
assert m.row(0) == DenseMatrix(1, 3, [1, 2, 3])
436+
assert m.col(0) == DenseMatrix(3, 1, [1, 4, 7])
437+
438+
439+
def test_is_square():
440+
m = DenseMatrix([[1],[1]])
441+
m2 = DenseMatrix([[2, 2], [2, 2]])
442+
assert not m.is_square
443+
assert m2.is_square
444+
445+
446+
def test_dot():
447+
A = DenseMatrix(2, 3, [1, 2, 3, 4, 5, 6])
448+
B = DenseMatrix(2, 1, [7, 8])
449+
assert A.dot(B) == DenseMatrix(1, 3, [39, 54, 69])
450+
assert ones(1, 3).dot(ones(3, 1)) == 3
451+
452+
453+
def test_cross():
454+
M = DenseMatrix(1, 3, [1, 2, 3])
455+
V = DenseMatrix(1, 3, [3, 4, 5])
456+
assert M.cross(V) == DenseMatrix(1, 3, [-2, 4, -2])
457+
raises(ShapeError, lambda:
458+
DenseMatrix(1, 2, [1, 1]).cross(DenseMatrix(1, 2, [1, 1])))
459+
460+
372461
def test_immutablematrix():
373462
A = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
374463
assert A.shape == (3, 3)

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d13fec95c651bbce195988a6d9a146e9b726b2c2
1+
978dfd1656d5dbc722f5dc448bbe96e97b2d7be9

0 commit comments

Comments
 (0)