Skip to content

Commit c9f975f

Browse files
committed
sympy_compat: Add atan2.
1 parent 9ec090d commit c9f975f

File tree

5 files changed

+31
-1
lines changed

5 files changed

+31
-1
lines changed

symengine/lib/symengine.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
256256
bool is_a_ComplexMPC "SymEngine::is_a<SymEngine::ComplexMPC>"(const Basic &b) nogil
257257
bool is_a_Log "SymEngine::is_a<SymEngine::Log>"(const Basic &b) nogil
258258
bool is_a_PyNumber "SymEngine::is_a<SymEngine::PyNumber>"(const Basic &b) nogil
259+
bool is_a_ATan2 "SymEngine::is_a<SymEngine::ATan2>"(const Basic &b) nogil
259260

260261
RCP[const Basic] expand(RCP[const Basic] &o) nogil except +
261262

@@ -417,6 +418,7 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
417418
cdef RCP[const Basic] function_symbol(string name, const vec_basic &arg) nogil except+
418419
cdef RCP[const Basic] abs(RCP[const Basic] &arg) nogil except+
419420
cdef RCP[const Basic] gamma(RCP[const Basic] &arg) nogil except+
421+
cdef RCP[const Basic] atan2(RCP[const Basic] &num, RCP[const Basic] &den) nogil except+
420422

421423
cdef cppclass Function(Basic):
422424
pass
@@ -512,6 +514,9 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
512514
cdef cppclass Gamma(Function):
513515
pass
514516

517+
cdef cppclass ATan2(Function):
518+
pass
519+
515520
IF HAVE_SYMENGINE_MPFR:
516521
cdef extern from "mpfr.h":
517522
ctypedef struct __mpfr_struct:

symengine/lib/symengine_wrapper.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ cdef class Abs(Function):
121121
cdef class Gamma(Function):
122122
pass
123123

124+
cdef class ATan2(Function):
125+
pass
126+
124127
cdef class Derivative(Basic):
125128
pass
126129

symengine/lib/symengine_wrapper.pyx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ cdef c2py(RCP[const symengine.Basic] o):
9999
r = ATanh.__new__(ATanh)
100100
elif (symengine.is_a_ACoth(deref(o))):
101101
r = ACoth.__new__(ACoth)
102+
elif (symengine.is_a_ATan2(deref(o))):
103+
r = ATan2.__new__(ATan2)
102104
elif (symengine.is_a_PyNumber(deref(o))):
103105
r = PyNumber.__new__(PyNumber)
104106
else:
@@ -169,6 +171,8 @@ def sympy2symengine(a, raise_error=False):
169171
return acsc(a.args[0])
170172
elif isinstance(a, sympy.asec):
171173
return asec(a.args[0])
174+
elif isinstance(a, sympy.atan2):
175+
return atan2(*a.args)
172176
elif isinstance(a, sympy.functions.elementary.hyperbolic.HyperbolicFunction):
173177
if isinstance(a, sympy.sinh):
174178
return sinh(a.args[0])
@@ -2159,6 +2163,11 @@ def gamma(x):
21592163
cdef Basic X = sympify(x)
21602164
return c2py(symengine.gamma(X.thisptr))
21612165

2166+
def atan2(x, y):
2167+
cdef Basic X = sympify(x)
2168+
cdef Basic Y = sympify(y)
2169+
return c2py(symengine.atan2(X.thisptr, Y.thisptr))
2170+
21622171
def eval_double(x):
21632172
cdef Basic X = sympify(x)
21642173
return c2py(<RCP[const symengine.Basic]>(symengine.real_double(symengine.eval_double(deref(X.thisptr)))))

symengine/sympy_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,12 @@ def __new__(cls, a):
240240
return symengine.acoth(a)
241241

242242

243+
class atan2(_RegisteredFunction):
244+
_classes = (symengine.ATan2,)
245+
246+
def __new__(cls, a, b):
247+
return symengine.atan2(a, b)
248+
243249
'''
244250
for i in ("""Sin Cos Tan Gamma Cot Csc Sec ASin ACos ATan
245251
ACot ACsc ASec Sinh Cosh Tanh Coth ASinh ACosh ATanh

symengine/tests/test_sympy_compat.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from symengine.sympy_compat import (Integer, Rational, S, Basic, Add, Mul,
2-
Pow, symbols, Symbol, log, sin, zeros)
2+
Pow, symbols, Symbol, log, sin, zeros, atan2)
33

44
def test_Integer():
55
i = Integer(5)
@@ -50,6 +50,13 @@ def test_log():
5050
i = log(x)
5151
assert isinstance(i, log)
5252

53+
def test_ATan2():
54+
x, y = symbols("x y")
55+
i = atan2(x, y)
56+
assert isinstance(i, atan2)
57+
i = atan2(0, 1)
58+
assert i == 0
59+
5360
def test_zeros():
5461
assert zeros(3, c=2).shape == (3, 2)
5562

0 commit comments

Comments
 (0)