Skip to content

Commit a2c0379

Browse files
committed
minimal set of changes for returning failure to find non-zero pivot in FFGJ
1 parent b0f15af commit a2c0379

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

symengine/lib/symengine.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ cdef extern from "<symengine/matrix.h>" namespace "SymEngine":
687687
DenseMatrix &B) except+ nogil
688688
void FFLU_solve "SymEngine::fraction_free_LU_solve"(const DenseMatrix &A,
689689
const DenseMatrix &b, DenseMatrix &x) except+ nogil
690-
void FFGJ_solve "SymEngine::fraction_free_gauss_jordan_solve"(const DenseMatrix &A,
690+
int FFGJ_solve "SymEngine::fraction_free_gauss_jordan_solve"(const DenseMatrix &A,
691691
const DenseMatrix &b, DenseMatrix &x) except+ nogil
692692
void LDL_solve "SymEngine::LDL_solve"(const DenseMatrix &A, const DenseMatrix &b,
693693
DenseMatrix &x) except+ nogil

symengine/lib/symengine_wrapper.in.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3998,9 +3998,11 @@ cdef class DenseMatrixBase(MatrixBase):
39983998
deref(symengine.static_cast_DenseMatrix(b_.thisptr)),
39993999
deref(symengine.static_cast_DenseMatrix(x.thisptr)))
40004000
elif method.upper() == 'FFGJ':
4001-
symengine.FFGJ_solve(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
4001+
failure = symengine.FFGJ_solve(deref(symengine.static_cast_DenseMatrix(self.thisptr)),
40024002
deref(symengine.static_cast_DenseMatrix(b_.thisptr)),
40034003
deref(symengine.static_cast_DenseMatrix(x.thisptr)))
4004+
if (failure != 0):
4005+
raise Exception("Underdetermined system. Failed to find non-zero pivot in column: %d" % failure)
40044006
else:
40054007
raise Exception("Unsupported method.")
40064008

symengine/tests/test_matrices.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
have_numpy = True
1212
except ImportError:
1313
have_numpy = False
14-
14+
1515
try:
1616
import sympy
1717
from sympy.core.cache import clear_cache
@@ -444,6 +444,15 @@ def test_solve():
444444
x = A.solve(b, 'FFGJ')
445445
assert x == y
446446

447+
B = DenseMatrix(2, 2, [0]*4)
448+
try:
449+
B.solve(b, 'FFGJ')
450+
except Exception:
451+
pass
452+
else:
453+
raise Exception("this operation should have raised an exception")
454+
455+
447456

448457
def test_FFLU():
449458
A = DenseMatrix(4, 4, [1, 2, 3, 4, 2, 2, 3, 4, 3, 3, 3, 4, 9, 8, 7, 6])

0 commit comments

Comments
 (0)