Skip to content

Commit 374d69c

Browse files
committed
Wrap cse
1 parent 3fd1e48 commit 374d69c

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

symengine/lib/symengine.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
219219
ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator"
220220
ctypedef unordered_map[rcp_const_basic, rcp_const_number] umap_basic_num "SymEngine::umap_basic_num"
221221
ctypedef unordered_map[rcp_const_basic, rcp_const_number].iterator umap_basic_num_iterator "SymEngine::umap_basic_num::iterator"
222+
ctypedef vector[pair[rcp_const_basic, rcp_const_basic]] vec_pair "SymEngine::vec_pair"
222223

223224
bool eq(const Basic &a, const Basic &b) nogil except +
224225
bool neq(const Basic &a, const Basic &b) nogil except +
@@ -312,6 +313,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
312313
bool is_a_Xor "SymEngine::is_a<SymEngine::Xor>"(const Basic &b) nogil
313314
RCP[const Basic] expand(RCP[const Basic] &o, bool deep) nogil except +
314315
void as_numer_denom(RCP[const Basic] &x, const Ptr[RCP[Basic]] &numer, const Ptr[RCP[Basic]] &denom) nogil
316+
void cse(vec_pair &replacements, vec_basic &reduced_exprs, const vec_basic &exprs) nogil except +
315317

316318
cdef extern from "<symengine/subs.h>" namespace "SymEngine":
317319
RCP[const Basic] msubs (RCP[const Basic] &x, const map_basic_basic &x) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from cython.operator cimport dereference as deref, preincrement as inc
22
cimport symengine
3-
from symengine cimport RCP, pair, map_basic_basic, umap_int_basic, umap_int_basic_iterator, umap_basic_num, umap_basic_num_iterator, rcp_const_basic, std_pair_short_rcp_const_basic, rcp_const_seriescoeffinterface
3+
from symengine cimport (RCP, pair, map_basic_basic, umap_int_basic,
4+
umap_int_basic_iterator, umap_basic_num, umap_basic_num_iterator,
5+
rcp_const_basic, std_pair_short_rcp_const_basic,
6+
rcp_const_seriescoeffinterface)
47
from libcpp cimport bool as cppbool
58
from libcpp.string cimport string
69
from libcpp.vector cimport vector
@@ -751,10 +754,24 @@ def get_dict(*args):
751754

752755

753756
cdef tuple vec_basic_to_tuple(symengine.vec_basic& vec):
757+
return tuple(vec_basic_to_list(vec))
758+
759+
760+
cdef list vec_basic_to_list(symengine.vec_basic& vec):
754761
result = []
755762
for i in range(vec.size()):
756763
result.append(c2py(<RCP[const symengine.Basic]>(vec[i])))
757-
return tuple(result)
764+
return result
765+
766+
767+
cdef list vec_pair_to_list(symengine.vec_pair& vec):
768+
result = []
769+
cdef RCP[const symengine.Basic] a, b
770+
for i in range(vec.size()):
771+
a = <RCP[const symengine.Basic]>vec[i].first
772+
b = <RCP[const symengine.Basic]>vec[i].second
773+
result.append((c2py(a), c2py(b)))
774+
return result
758775

759776

760777
cdef class Basic(object):
@@ -4789,5 +4806,17 @@ def solve(f, sym, domain=None):
47894806
return c2py(<RCP[const symengine.Basic]>(symengine.solve(f_.thisptr, x, d)))
47904807

47914808

4809+
def cse(exprs):
4810+
cdef symengine.vec_basic vec
4811+
cdef symengine.vec_pair replacements
4812+
cdef symengine.vec_basic reduced_exprs
4813+
cdef Basic b
4814+
for expr in exprs:
4815+
b = sympify(expr)
4816+
vec.push_back(b.thisptr)
4817+
symengine.cse(replacements, reduced_exprs, vec)
4818+
return (vec_pair_to_list(replacements), vec_basic_to_list(reduced_exprs))
4819+
4820+
47924821
# Turn on nice stacktraces:
47934822
symengine.print_stack_on_segfault()

0 commit comments

Comments
 (0)