Skip to content

Commit b03f044

Browse files
authored
Merge pull request #193 from isuruf/cse2
Wrap cse
2 parents 3fd1e48 + 2aab90a commit b03f044

File tree

7 files changed

+56
-7
lines changed

7 files changed

+56
-7
lines changed

appveyor.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ install:
8181

8282
- set PATH=C:\Python%PYTHON_VERSION%;C:\Python%PYTHON_VERSION%\Scripts;%PATH%
8383
- pip install nose pytest
84-
- pip install --install-option="--no-cython-compile" cython
84+
- if [%COMPILER%]==[MinGW-w64] pip install --install-option="--no-cython-compile" cython==0.26
85+
- if NOT [%COMPILER%]==[MinGW-w64] pip install --install-option="--no-cython-compile" cython
8586
- if NOT [%WITH_NUMPY%]==[no] pip install numpy
8687
- if NOT [%WITH_SYMPY%]==[no] pip install sympy
8788

symengine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
LessThan, StrictGreaterThan, StrictLessThan, Eq, Ne, Ge, Le,
1212
Gt, Lt, golden_ratio as GoldenRatio, catalan as Catalan,
1313
eulergamma as EulerGamma, Dummy, perfect_power, integer_nthroot,
14-
isprime, sqrt_mod, Expr)
14+
isprime, sqrt_mod, Expr, cse)
1515
from .utilities import var, symbols
1616
from .functions import *
1717

symengine/lib/pywrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ hash_t PyFunction::__hash__() const {
261261
bool PyFunction::__eq__(const Basic &o) const {
262262
if (is_a<PyFunction>(o) and
263263
pyfunction_class_->__eq__(*static_cast<const PyFunction &>(o).get_pyfunction_class()) and
264-
unified_eq(arg_, static_cast<const PyFunction &>(o).arg_))
264+
unified_eq(get_vec(), static_cast<const PyFunction &>(o).get_vec()))
265265
return true;
266266
return false;
267267
}
@@ -271,7 +271,7 @@ int PyFunction::compare(const Basic &o) const {
271271
const PyFunction &s = static_cast<const PyFunction &>(o);
272272
int cmp = pyfunction_class_->compare(*s.get_pyfunction_class());
273273
if (cmp != 0) return cmp;
274-
return unified_compare(arg_, s.arg_);
274+
return unified_compare(get_vec(), s.get_vec());
275275
}
276276

277277
} // SymEngine

symengine/lib/symengine.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
219219
ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator"
220220
ctypedef unordered_map[rcp_const_basic, rcp_const_number] umap_basic_num "SymEngine::umap_basic_num"
221221
ctypedef unordered_map[rcp_const_basic, rcp_const_number].iterator umap_basic_num_iterator "SymEngine::umap_basic_num::iterator"
222+
ctypedef vector[pair[rcp_const_basic, rcp_const_basic]] vec_pair "SymEngine::vec_pair"
222223

223224
bool eq(const Basic &a, const Basic &b) nogil except +
224225
bool neq(const Basic &a, const Basic &b) nogil except +
@@ -312,6 +313,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
312313
bool is_a_Xor "SymEngine::is_a<SymEngine::Xor>"(const Basic &b) nogil
313314
RCP[const Basic] expand(RCP[const Basic] &o, bool deep) nogil except +
314315
void as_numer_denom(RCP[const Basic] &x, const Ptr[RCP[Basic]] &numer, const Ptr[RCP[Basic]] &denom) nogil
316+
void cse(vec_pair &replacements, vec_basic &reduced_exprs, const vec_basic &exprs) nogil except +
315317

316318
cdef extern from "<symengine/subs.h>" namespace "SymEngine":
317319
RCP[const Basic] msubs (RCP[const Basic] &x, const map_basic_basic &x) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from cython.operator cimport dereference as deref, preincrement as inc
22
cimport symengine
3-
from symengine cimport RCP, pair, map_basic_basic, umap_int_basic, umap_int_basic_iterator, umap_basic_num, umap_basic_num_iterator, rcp_const_basic, std_pair_short_rcp_const_basic, rcp_const_seriescoeffinterface
3+
from symengine cimport (RCP, pair, map_basic_basic, umap_int_basic,
4+
umap_int_basic_iterator, umap_basic_num, umap_basic_num_iterator,
5+
rcp_const_basic, std_pair_short_rcp_const_basic,
6+
rcp_const_seriescoeffinterface)
47
from libcpp cimport bool as cppbool
58
from libcpp.string cimport string
69
from libcpp.vector cimport vector
@@ -751,10 +754,24 @@ def get_dict(*args):
751754

752755

753756
cdef tuple vec_basic_to_tuple(symengine.vec_basic& vec):
757+
return tuple(vec_basic_to_list(vec))
758+
759+
760+
cdef list vec_basic_to_list(symengine.vec_basic& vec):
754761
result = []
755762
for i in range(vec.size()):
756763
result.append(c2py(<RCP[const symengine.Basic]>(vec[i])))
757-
return tuple(result)
764+
return result
765+
766+
767+
cdef list vec_pair_to_list(symengine.vec_pair& vec):
768+
result = []
769+
cdef RCP[const symengine.Basic] a, b
770+
for i in range(vec.size()):
771+
a = <RCP[const symengine.Basic]>vec[i].first
772+
b = <RCP[const symengine.Basic]>vec[i].second
773+
result.append((c2py(a), c2py(b)))
774+
return result
758775

759776

760777
cdef class Basic(object):
@@ -4789,5 +4806,17 @@ def solve(f, sym, domain=None):
47894806
return c2py(<RCP[const symengine.Basic]>(symengine.solve(f_.thisptr, x, d)))
47904807

47914808

4809+
def cse(exprs):
4810+
cdef symengine.vec_basic vec
4811+
cdef symengine.vec_pair replacements
4812+
cdef symengine.vec_basic reduced_exprs
4813+
cdef Basic b
4814+
for expr in exprs:
4815+
b = sympify(expr)
4816+
vec.push_back(b.thisptr)
4817+
symengine.cse(replacements, reduced_exprs, vec)
4818+
return (vec_pair_to_list(replacements), vec_basic_to_list(reduced_exprs))
4819+
4820+
47924821
# Turn on nice stacktraces:
47934822
symengine.print_stack_on_segfault()

symengine/tests/test_cse.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from symengine import cse, sqrt, symbols
2+
3+
def test_cse_single():
4+
x, y, x0 = symbols("x, y, x0")
5+
e = pow(x + y, 2) + sqrt(x + y)
6+
substs, reduced = cse([e])
7+
assert substs == [(x0, x + y)]
8+
assert reduced == [sqrt(x0) + x0**2]
9+
10+
11+
def test_multiple_expressions():
12+
w, x, y, z, x0 = symbols("w, x, y, z, x0")
13+
e1 = (x + y)*z
14+
e2 = (x + y)*w
15+
substs, reduced = cse([e1, e2])
16+
assert substs == [(x0, x + y)]
17+
assert reduced == [x0*z, x0*w]

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
a2a839b5d8c4eab1560ed7ed56c805e1160344e3
1+
fdf132fcb4425589b69b40d60a90234944870b28

0 commit comments

Comments
 (0)