Skip to content

Commit 3541ffc

Browse files
authored
Merge pull request #172 from ShikharJ/Misc
Wrap Miscellaneous Functions
2 parents 69c12f8 + 81104e2 commit 3541ffc

File tree

7 files changed

+166
-16
lines changed

7 files changed

+166
-16
lines changed

symengine/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
UndefFunction, Function, FunctionSymbol as AppliedUndef,
1010
have_numpy, true, false, Equality, Unequality, GreaterThan,
1111
LessThan, StrictGreaterThan, StrictLessThan, Eq, Ne, Ge, Le,
12-
Gt, Lt, GoldenRatio, Catalan, EulerGamma)
12+
Gt, Lt, GoldenRatio, Catalan, EulerGamma, Dummy, perfect_power,
13+
integer_nthroot, isprime, sqrt_mod)
1314
from .utilities import var, symbols
1415
from .functions import *
1516

symengine/lib/symengine.pxd

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
227227
bool is_a_Rational "SymEngine::is_a<SymEngine::Rational>"(const Basic &b) nogil
228228
bool is_a_Complex "SymEngine::is_a<SymEngine::Complex>"(const Basic &b) nogil
229229
bool is_a_Symbol "SymEngine::is_a<SymEngine::Symbol>"(const Basic &b) nogil
230+
bool is_a_Dummy "SymEngine::is_a<SymEngine::Dummy>"(const Basic &b) nogil
230231
bool is_a_Constant "SymEngine::is_a<SymEngine::Constant>"(const Basic &b) nogil
231232
bool is_a_Infty "SymEngine::is_a<SymEngine::Infty>"(const Basic &b) nogil
232233
bool is_a_NaN "SymEngine::is_a<SymEngine::NaN>"(const Basic &b) nogil
@@ -305,6 +306,8 @@ cdef extern from "<symengine/symbol.h>" namespace "SymEngine":
305306
cdef cppclass Symbol(Basic):
306307
Symbol(string name) nogil
307308
string get_name() nogil
309+
cdef cppclass Dummy(Symbol):
310+
pass
308311

309312
cdef extern from "<symengine/number.h>" namespace "SymEngine":
310313
cdef cppclass Number(Basic):
@@ -339,6 +342,9 @@ cdef extern from "<symengine/integer.h>" namespace "SymEngine":
339342
integer_class as_mpz() nogil
340343
cdef RCP[const Integer] integer(int i) nogil
341344
cdef RCP[const Integer] integer(integer_class i) nogil
345+
int i_nth_root(const Ptr[RCP[Integer]] &r, const Integer &a, unsigned long int n) nogil
346+
bool perfect_square(const Integer &n) nogil
347+
bool perfect_power(const Integer &n) nogil
342348

343349
cdef extern from "<symengine/rational.h>" namespace "SymEngine":
344350
cdef cppclass Rational(Number):
@@ -414,20 +420,17 @@ cdef extern from "<symengine/pow.h>" namespace "SymEngine":
414420
cdef RCP[const Basic] pow(RCP[const Basic] &a, RCP[const Basic] &b) nogil except+
415421
cdef RCP[const Basic] sqrt(RCP[const Basic] &x) nogil except+
416422
cdef RCP[const Basic] exp(RCP[const Basic] &x) nogil except+
417-
cdef RCP[const Basic] log(RCP[const Basic] &x) nogil except+
418-
cdef RCP[const Basic] log(RCP[const Basic] &x, RCP[const Basic] &y) nogil except+
419423

420424
cdef cppclass Pow(Basic):
421425
RCP[const Basic] get_base() nogil
422426
RCP[const Basic] get_exp() nogil
423427

424-
cdef cppclass Log(Basic):
425-
RCP[const Basic] get_arg() nogil
426-
427428

428429
cdef extern from "<symengine/basic.h>" namespace "SymEngine":
429430
# We need to specialize these for our classes:
430431
RCP[const Basic] make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
432+
RCP[const Basic] make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
433+
RCP[const Basic] make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
431434
RCP[const Basic] make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj) nogil
432435
RCP[const Basic] make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
433436
RCP[const Basic] make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil
@@ -499,7 +502,8 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
499502
cdef RCP[const Basic] floor(RCP[const Basic] &x) nogil except+
500503
cdef RCP[const Basic] ceiling(RCP[const Basic] &x) nogil except+
501504
cdef RCP[const Basic] conjugate(RCP[const Basic] &x) nogil except+
502-
505+
cdef RCP[const Basic] log(RCP[const Basic] &x) nogil except+
506+
cdef RCP[const Basic] log(RCP[const Basic] &x, RCP[const Basic] &y) nogil except+
503507

504508
cdef cppclass Function(Basic):
505509
pass
@@ -667,6 +671,9 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
667671
cdef cppclass Conjugate(OneArgFunction):
668672
pass
669673

674+
cdef cppclass Log(Function):
675+
pass
676+
670677
IF HAVE_SYMENGINE_MPFR:
671678
cdef extern from "mpfr.h":
672679
ctypedef struct __mpfr_struct:

symengine/lib/symengine_wrapper.pyx

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ cdef c2py(RCP[const symengine.Basic] o):
4242
r = Number.__new__(Rational)
4343
elif (symengine.is_a_Complex(deref(o))):
4444
r = Complex.__new__(Complex)
45+
elif (symengine.is_a_Dummy(deref(o))):
46+
r = Symbol.__new__(Dummy)
4547
elif (symengine.is_a_Symbol(deref(o))):
4648
if (symengine.is_a_PySymbol(deref(o))):
4749
return <object>(deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object())
@@ -197,6 +199,8 @@ def sympy2symengine(a, raise_error=False):
197199
from sympy.core.function import AppliedUndef as sympy_AppliedUndef
198200
if isinstance(a, sympy.Symbol):
199201
return Symbol(a.name)
202+
elif isinstance(a, sympy.Dummy):
203+
return Dummy(a.name)
200204
elif isinstance(a, sympy.Mul):
201205
return mul(*[sympy2symengine(x, raise_error) for x in a.args])
202206
elif isinstance(a, sympy.Add):
@@ -710,6 +714,8 @@ cdef class Basic(object):
710714
else:
711715
return eval(self, prec)
712716

717+
evalf = n
718+
713719
@property
714720
def args(self):
715721
cdef symengine.vec_basic args = deref(self.thisptr).get_args()
@@ -895,6 +901,27 @@ class Symbol(Basic):
895901
return self.__class__
896902

897903

904+
class Dummy(Symbol):
905+
906+
def __init__(Basic self, name=None, *args, **kwargs):
907+
if name is None:
908+
self.thisptr = symengine.make_rcp_Dummy()
909+
else:
910+
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
911+
912+
def _sympy_(self):
913+
import sympy
914+
return sympy.Dummy(str(self))
915+
916+
@property
917+
def is_Dummy(self):
918+
return True
919+
920+
@property
921+
def func(self):
922+
return self.__class__
923+
924+
898925
def symarray(prefix, shape, **kwargs):
899926
""" Creates an nd-array of symbols
900927
@@ -2956,6 +2983,8 @@ def diff(ex, *x):
29562983
def expand(x):
29572984
return sympify(x).expand()
29582985

2986+
expand_mul = expand
2987+
29592988
def function_symbol(name, *args):
29602989
cdef symengine.vec_basic v
29612990
cdef Basic e_
@@ -2973,6 +3002,23 @@ def exp(x):
29733002
cdef Basic X = sympify(x)
29743003
return c2py(symengine.exp(X.thisptr))
29753004

3005+
def perfect_power(n):
3006+
cdef Basic _n = sympify(n)
3007+
require(_n, Integer)
3008+
return symengine.perfect_power(deref(symengine.rcp_static_cast_Integer(_n.thisptr)))
3009+
3010+
def is_square(n):
3011+
cdef Basic _n = sympify(n)
3012+
require(_n, Integer)
3013+
return symengine.perfect_square(deref(symengine.rcp_static_cast_Integer(_n.thisptr)))
3014+
3015+
def integer_nthroot(a, n):
3016+
cdef Basic _a = sympify(a)
3017+
require(_a, Integer)
3018+
cdef RCP[const symengine.Integer] _r
3019+
cdef int ret_val = symengine.i_nth_root(symengine.outArg_Integer(_r), deref(symengine.rcp_static_cast_Integer(_a.thisptr)), n)
3020+
return (c2py(<RCP[const symengine.Basic]>_r), ret_val == 1)
3021+
29763022
def _max(*args):
29773023
cdef symengine.vec_basic v
29783024
cdef Basic e_
@@ -3103,6 +3149,8 @@ def probab_prime_p(n, reps = 25):
31033149
require(_n, Integer)
31043150
return symengine.probab_prime_p(deref(symengine.rcp_static_cast_Integer(_n.thisptr)), reps) >= 1
31053151

3152+
isprime = probab_prime_p
3153+
31063154
def nextprime(n):
31073155
cdef Basic _n = sympify(n)
31083156
require(_n, Integer)
@@ -3132,7 +3180,9 @@ def gcd_ext(a, b):
31323180
cdef RCP[const symengine.Integer] g, s, t
31333181
symengine.gcd_ext(symengine.outArg_Integer(g), symengine.outArg_Integer(s), symengine.outArg_Integer(t),
31343182
deref(symengine.rcp_static_cast_Integer(_a.thisptr)), deref(symengine.rcp_static_cast_Integer(_b.thisptr)))
3135-
return [c2py(<RCP[const symengine.Basic]>g), c2py(<RCP[const symengine.Basic]>s), c2py(<RCP[const symengine.Basic]>t)]
3183+
return (c2py(<RCP[const symengine.Basic]>s), c2py(<RCP[const symengine.Basic]>t), c2py(<RCP[const symengine.Basic]>g))
3184+
3185+
igcdex = gcd_ext
31363186

31373187
def mod(a, b):
31383188
if b == 0:
@@ -3418,6 +3468,11 @@ def nthroot_mod_list(a, n, m):
34183468
s.append(c2py(<RCP[const symengine.Basic]>(root_list[i])))
34193469
return s
34203470

3471+
def sqrt_mod(a, p, all_roots=False):
3472+
if all_roots:
3473+
return nthroot_mod_list(a, 2, p)
3474+
return nthroot_mod(a, 2, p)
3475+
34213476
def powermod(a, b, m):
34223477
cdef Basic _a = sympify(a)
34233478
cdef Basic _m = sympify(m)

symengine/tests/test_ntheory.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
from symengine.utilities import raises
22

3-
from symengine.lib.symengine_wrapper import (probab_prime_p, nextprime, gcd,
3+
from symengine.lib.symengine_wrapper import (isprime, nextprime, gcd,
44
lcm, gcd_ext, mod, quotient, quotient_mod, mod_inverse, crt, fibonacci,
55
fibonacci2, lucas, lucas2, binomial, factorial, divides, factor,
66
factor_lehman_method, factor_pollard_pm1_method, factor_pollard_rho_method,
77
prime_factors, prime_factor_multiplicities, Sieve, Sieve_iterator,
88
bernoulli, primitive_root, primitive_root_list, totient, carmichael,
99
multiplicative_order, legendre, jacobi, kronecker, nthroot_mod,
10-
nthroot_mod_list, powermod, powermod_list, Integer)
10+
nthroot_mod_list, powermod, powermod_list, Integer, sqrt_mod)
1111

1212

1313
def test_probab_prime_p():
14-
assert probab_prime_p(101) is True
15-
assert probab_prime_p(100) is False
14+
s = set(Sieve.generate_primes(1000))
15+
for n in range(1001):
16+
assert (n in s) == isprime(n)
1617

1718

1819
def test_nextprime():
@@ -32,10 +33,13 @@ def test_lcm():
3233

3334

3435
def test_gcd_ext():
35-
[p, q, r] = gcd_ext(6, 9)
36+
(q, r, p) = gcd_ext(6, 9)
3637
assert p == q * 6 + r * 9
37-
[p, q, r] = gcd_ext(-15, 10)
38+
(q, r, p) = gcd_ext(-15, 10)
3839
assert p == q * -15 + r * 10
40+
assert gcd_ext(2, 3) == (-1, 1, 1)
41+
assert gcd_ext(10, 12) == (-1, 1, 2)
42+
assert gcd_ext(100, 2004) == (-20, 1, 4)
3943

4044

4145
def test_mod():
@@ -212,6 +216,22 @@ def test_nthroot_mod():
212216
assert nthroot_mod(3, 2, 5) is None
213217

214218

219+
def test_sqrt_mod():
220+
assert sqrt_mod(3, 13) == 9
221+
assert sqrt_mod(6, 23) == 12
222+
assert sqrt_mod(345, 690) == 345
223+
assert sqrt_mod(9, 27, True) == [3, 6, 12, 15, 21, 24]
224+
assert sqrt_mod(9, 81, True) == [3, 24, 30, 51, 57, 78]
225+
assert sqrt_mod(9, 3**5, True) == [3, 78, 84, 159, 165, 240]
226+
assert sqrt_mod(81, 3**4, True) == [0, 9, 18, 27, 36, 45, 54, 63, 72]
227+
assert sqrt_mod(81, 3**5, True) == [9, 18, 36, 45, 63, 72, 90, 99, 117,\
228+
126, 144, 153, 171, 180, 198, 207, 225, 234]
229+
assert sqrt_mod(81, 3**6, True) == [9, 72, 90, 153, 171, 234, 252, 315,\
230+
333, 396, 414, 477, 495, 558, 576, 639, 657, 720]
231+
assert sqrt_mod(81, 3**7, True) == [9, 234, 252, 477, 495, 720, 738, 963,\
232+
981, 1206, 1224, 1449, 1467, 1692, 1710, 1935, 1953, 2178]
233+
234+
215235
def test_nthroot_mod_list():
216236
assert nthroot_mod_list(-4, 4, 65) == [4, 6, 7, 9, 17, 19, 22, 32,
217237
33, 43, 46, 48, 56, 58, 59, 61]

symengine/tests/test_number.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from symengine.utilities import raises
22

33
from symengine import Integer, I
4+
from symengine.lib.symengine_wrapper import (perfect_power, is_square, integer_nthroot)
45

56

67
def test_integer():
@@ -63,3 +64,57 @@ def test_is_conditions():
6364
assert not i.is_nonpositive
6465
assert not i.is_nonnegative
6566
assert i.is_complex
67+
68+
69+
def test_perfect_power():
70+
assert perfect_power(1) == True
71+
assert perfect_power(7) == False
72+
assert perfect_power(8) == True
73+
assert perfect_power(9) == True
74+
assert perfect_power(10) == False
75+
assert perfect_power(1024) == True
76+
assert perfect_power(1025) == False
77+
assert perfect_power(6**7) == True
78+
assert perfect_power(-27) == True
79+
assert perfect_power(-64) == True
80+
assert perfect_power(-32) == True
81+
82+
83+
def test_perfect_square():
84+
assert is_square(7) == False
85+
assert is_square(8) == False
86+
assert is_square(9) == True
87+
assert is_square(10) == False
88+
assert perfect_power(49) == True
89+
assert perfect_power(50) == False
90+
91+
92+
def test_integer_nthroot():
93+
assert integer_nthroot(1, 2) == (1, True)
94+
assert integer_nthroot(1, 5) == (1, True)
95+
assert integer_nthroot(2, 1) == (2, True)
96+
assert integer_nthroot(2, 2) == (1, False)
97+
assert integer_nthroot(2, 5) == (1, False)
98+
assert integer_nthroot(4, 2) == (2, True)
99+
assert integer_nthroot(123**25, 25) == (123, True)
100+
assert integer_nthroot(123**25 + 1, 25) == (123, False)
101+
assert integer_nthroot(123**25 - 1, 25) == (122, False)
102+
assert integer_nthroot(1, 1) == (1, True)
103+
assert integer_nthroot(0, 1) == (0, True)
104+
assert integer_nthroot(0, 3) == (0, True)
105+
assert integer_nthroot(10000, 1) == (10000, True)
106+
assert integer_nthroot(4, 2) == (2, True)
107+
assert integer_nthroot(16, 2) == (4, True)
108+
assert integer_nthroot(26, 2) == (5, False)
109+
assert integer_nthroot(1234567**7, 7) == (1234567, True)
110+
assert integer_nthroot(1234567**7 + 1, 7) == (1234567, False)
111+
assert integer_nthroot(1234567**7 - 1, 7) == (1234566, False)
112+
b = 25**1000
113+
assert integer_nthroot(b, 1000) == (25, True)
114+
assert integer_nthroot(b + 1, 1000) == (25, False)
115+
assert integer_nthroot(b - 1, 1000) == (24, False)
116+
c = 10**400
117+
c2 = c**2
118+
assert integer_nthroot(c2, 2) == (c, True)
119+
assert integer_nthroot(c2 + 1, 2) == (c, False)
120+
assert integer_nthroot(c2 - 1, 2) == (c - 1, False)

symengine/tests/test_symbol.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from symengine import Symbol, symbols, symarray, has_symbol
1+
from symengine import Symbol, symbols, symarray, has_symbol, Dummy
22
from symengine.utilities import raises
33

44

@@ -147,3 +147,15 @@ def test_has_symbol():
147147
assert not has_symbol(2, a)
148148
assert not has_symbol(c, a)
149149
assert has_symbol(a+b, b)
150+
151+
def test_dummy():
152+
x1 = Symbol('x')
153+
x2 = Symbol('x')
154+
xdummy1 = Dummy('x')
155+
xdummy2 = Dummy('x')
156+
157+
assert x1 == x2
158+
assert x1 != xdummy1
159+
assert xdummy1 != xdummy2
160+
assert Dummy() != Dummy()
161+
assert Dummy('x') != Dummy('x')

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3e87d58a9404535fbc2bb8d09357faf4ab8eb5a1
1+
d13fec95c651bbce195988a6d9a146e9b726b2c2

0 commit comments

Comments
 (0)