Skip to content

Commit a25ccb4

Browse files
authored
Merge pull request #307 from isuruf/linsolve
Linsolve
2 parents 740b26f + c2597bd commit a25ccb4

File tree

5 files changed

+64
-38
lines changed

5 files changed

+64
-38
lines changed

symengine/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
LessThan, StrictGreaterThan, StrictLessThan, Eq, Ne, Ge, Le,
1313
Gt, Lt, And, Or, Not, Nand, Nor, Xor, Xnor, perfect_power, integer_nthroot,
1414
isprime, sqrt_mod, Expr, cse, count_ops, ccode, Piecewise, Contains, Interval, FiniteSet,
15+
EmptySet, linsolve,
1516
FunctionSymbol as AppliedUndef,
1617
golden_ratio as GoldenRatio,
1718
catalan as Catalan,

symengine/lib/symengine.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
171171
void insert(iterator, iterator) except +
172172

173173
ctypedef vector[rcp_const_basic] vec_basic "SymEngine::vec_basic"
174+
ctypedef vector[RCP[Symbol]] vec_sym "SymEngine::vec_sym"
174175
ctypedef vector[RCP[Integer]] vec_integer "SymEngine::vec_integer"
175176
ctypedef map[RCP[Integer], unsigned] map_integer_uint "SymEngine::map_integer_uint"
176177
cdef struct RCPIntegerKeyLess
@@ -1047,6 +1048,7 @@ cdef extern from "<symengine/sets.h>" namespace "SymEngine":
10471048
cdef extern from "<symengine/solve.h>" namespace "SymEngine":
10481049
cdef RCP[const Set] solve(rcp_const_basic &f, RCP[const Symbol] &sym) nogil except +
10491050
cdef RCP[const Set] solve(rcp_const_basic &f, RCP[const Symbol] &sym, RCP[const Set] &domain) nogil except +
1051+
cdef vec_basic linsolve(const vec_basic &eqs, const vec_sym &syms) nogil except +
10501052

10511053
cdef extern from "<symengine/printers.h>" namespace "SymEngine":
10521054
string ccode(const Basic &x) nogil except +

symengine/lib/symengine_wrapper.pyx

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,11 +2068,7 @@ class Add(AssocOp):
20682068
identity = 0
20692069

20702070
def __new__(cls, *args, **kwargs):
2071-
cdef symengine.vec_basic v_
2072-
cdef Basic e
2073-
for e_ in args:
2074-
e = _sympify(e_)
2075-
v_.push_back(e.thisptr)
2071+
cdef symengine.vec_basic v_ = iter_to_vec_basic(args)
20762072
return c2py(symengine.add(v_))
20772073

20782074
@classmethod
@@ -2123,11 +2119,7 @@ class Mul(AssocOp):
21232119
identity = 1
21242120

21252121
def __new__(cls, *args, **kwargs):
2126-
cdef symengine.vec_basic v_
2127-
cdef Basic e
2128-
for e_ in args:
2129-
e = _sympify(e_)
2130-
v_.push_back(e.thisptr)
2122+
cdef symengine.vec_basic v_ = iter_to_vec_basic(args)
21312123
return c2py(symengine.mul(v_))
21322124

21332125
@classmethod
@@ -2296,11 +2288,7 @@ class KroneckerDelta(Function):
22962288

22972289
class LeviCivita(Function):
22982290
def __new__(cls, *args):
2299-
cdef symengine.vec_basic v
2300-
cdef Basic e_
2301-
for e in args:
2302-
e_ = sympify(e)
2303-
v.push_back(e_.thisptr)
2291+
cdef symengine.vec_basic v = iter_to_vec_basic(args)
23042292
return c2py(symengine.levi_civita(v))
23052293

23062294
def _sympy_(self):
@@ -2710,11 +2698,7 @@ class PyFunction(FunctionSymbol):
27102698
def __init__(Basic self, pyfunction = None, args = None, pyfunction_class=None, module=None):
27112699
if pyfunction is None:
27122700
return
2713-
cdef symengine.vec_basic v
2714-
cdef Basic arg_
2715-
for arg in args:
2716-
arg_ = sympify(arg)
2717-
v.push_back(arg_.thisptr)
2701+
cdef symengine.vec_basic v = iter_to_vec_basic(args)
27182702
cdef PyFunctionClass _pyfunction_class = get_function_class(pyfunction_class, module)
27192703
cdef PyObject* _pyfunction = <PyObject*>pyfunction
27202704
Py_XINCREF(_pyfunction)
@@ -3785,42 +3769,53 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):
37853769

37863770
ImmutableMatrix = ImmutableDenseMatrix
37873771

3772+
37883773
cdef matrix_to_vec(DenseMatrixBase d, symengine.vec_basic& v):
37893774
cdef Basic e_
37903775
for i in range(d.nrows()):
37913776
for j in range(d.ncols()):
37923777
e_ = d._get(i, j)
37933778
v.push_back(e_.thisptr)
37943779

3780+
37953781
def eye(n):
37963782
cdef DenseMatrixBase d = DenseMatrix(n, n)
37973783
symengine.eye(deref(symengine.static_cast_DenseMatrix(d.thisptr)), 0)
37983784
return d
37993785

3800-
def diag(*values):
3801-
cdef DenseMatrixBase d = DenseMatrix(len(values), len(values))
3802-
cdef symengine.vec_basic V
3786+
3787+
cdef symengine.vec_basic iter_to_vec_basic(iter):
38033788
cdef Basic B
3804-
for b in values:
3789+
cdef symengine.vec_basic V
3790+
for b in iter:
38053791
B = sympify(b)
38063792
V.push_back(B.thisptr)
3793+
return V
3794+
3795+
3796+
def diag(*values):
3797+
cdef DenseMatrixBase d = DenseMatrix(len(values), len(values))
3798+
cdef symengine.vec_basic V = iter_to_vec_basic(values)
38073799
symengine.diag(deref(symengine.static_cast_DenseMatrix(d.thisptr)), V, 0)
38083800
return d
38093801

3802+
38103803
def ones(r, c = None):
38113804
if c is None:
38123805
c = r
38133806
cdef DenseMatrixBase d = DenseMatrix(r, c)
38143807
symengine.ones(deref(symengine.static_cast_DenseMatrix(d.thisptr)))
38153808
return d
38163809

3810+
38173811
def zeros(r, c = None):
38183812
if c is None:
38193813
c = r
38203814
cdef DenseMatrixBase d = DenseMatrix(r, c)
38213815
symengine.zeros(deref(symengine.static_cast_DenseMatrix(d.thisptr)))
38223816
return d
38233817

3818+
38243819
cdef class Sieve:
38253820
@staticmethod
38263821
def generate_primes(n):
@@ -3831,6 +3826,7 @@ cdef class Sieve:
38313826
s.append(primes[i])
38323827
return s
38333828

3829+
38343830
cdef class Sieve_iterator:
38353831
cdef symengine.sieve_iterator *thisptr
38363832
cdef unsigned limit
@@ -5000,6 +4996,25 @@ def solve(f, sym, domain=None):
50004996
return c2py(<rcp_const_basic>(symengine.solve(f_.thisptr, x, d)))
50014997

50024998

4999+
def linsolve(eqs, syms):
5000+
"""
5001+
Solve a set of linear equations given as an iterable `eqs`
5002+
which are linear w.r.t the symbols given as an iterable `syms`
5003+
"""
5004+
cdef symengine.vec_basic eqs_ = iter_to_vec_basic(eqs)
5005+
cdef symengine.vec_sym syms_
5006+
cdef RCP[const symengine.Symbol] sym_
5007+
cdef Symbol B
5008+
for sym in syms:
5009+
B = sympify(sym)
5010+
sym_ = symengine.rcp_static_cast_Symbol(B.thisptr)
5011+
syms_.push_back(sym_)
5012+
if syms_.size() != eqs_.size():
5013+
raise RuntimeError("Number of equations and symbols do not match")
5014+
cdef symengine.vec_basic ret = symengine.linsolve(eqs_, syms_)
5015+
return vec_basic_to_tuple(ret)
5016+
5017+
50035018
def cse(exprs):
50045019
cdef symengine.vec_basic vec
50055020
cdef symengine.vec_pair replacements

symengine/tests/test_solve.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
from symengine.utilities import raises
2-
from symengine.lib.symengine_wrapper import (Interval, EmptySet, FiniteSet,
3-
I, oo, solve, Eq, Symbol)
2+
from symengine import (Interval, EmptySet, FiniteSet, I, oo, Eq, Symbol,
3+
linsolve)
4+
from symengine.lib.symengine_wrapper import solve
45

56
def test_solve():
6-
x = Symbol("x")
7-
reals = Interval(-oo, oo)
7+
x = Symbol("x")
8+
reals = Interval(-oo, oo)
89

9-
assert solve(1, x, reals) == EmptySet()
10-
assert solve(0, x, reals) == reals
11-
assert solve(x + 3, x, reals) == FiniteSet(-3)
12-
assert solve(x + 3, x, Interval(0, oo)) == EmptySet()
13-
assert solve(x, x, reals) == FiniteSet(0)
14-
assert solve(x**2 + 1, x) == FiniteSet(-I, I)
15-
assert solve(x**2 - 2*x + 1, x) == FiniteSet(1)
16-
assert solve(Eq(x**3 + 3*x**2 + 3*x, -1), x, reals) == FiniteSet(-1)
17-
assert solve(x**3 - x, x) == FiniteSet(0, 1, -1)
10+
assert solve(1, x, reals) == EmptySet()
11+
assert solve(0, x, reals) == reals
12+
assert solve(x + 3, x, reals) == FiniteSet(-3)
13+
assert solve(x + 3, x, Interval(0, oo)) == EmptySet()
14+
assert solve(x, x, reals) == FiniteSet(0)
15+
assert solve(x**2 + 1, x) == FiniteSet(-I, I)
16+
assert solve(x**2 - 2*x + 1, x) == FiniteSet(1)
17+
assert solve(Eq(x**3 + 3*x**2 + 3*x, -1), x, reals) == FiniteSet(-1)
18+
assert solve(x**3 - x, x) == FiniteSet(0, 1, -1)
19+
20+
def test_linsolve():
21+
x = Symbol("x")
22+
y = Symbol("y")
23+
assert linsolve([x - 2], [x]) == (2,)
24+
assert linsolve([x - 2, y - 3], [x, y]) == (2, 3)
25+
assert linsolve([x + y - 3, x + 2*y - 4], [x, y]) == (2, 1)

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v0.5.0
1+
fc05f8d55915c2de956e7797d764eb1116b61711

0 commit comments

Comments
 (0)