11using SpecialFunctions
22
3- function IMPLEMENT_ONE_ARG_FUNC (meth, symnm; lib= :basic_ )
3+ function IMPLEMENT_ONE_ARG_FUNC (modu, meth, symnm; lib= :basic_ )
4+ methbang = Symbol (meth, " !" )
5+ if isa (modu, Symbol)
6+ meth = :($ modu.$ meth)
7+ end
48 @eval begin
59 function ($ meth)(b:: SymbolicType )
610 a = Basic ()
11+ ($ methbang)(a,b)
12+ a
13+ end
14+ function ($ methbang)(a:: Basic , b:: SymbolicType )
715 err_code = ccall (($ (string (lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}), a, b)
816 throw_if_error (err_code, $ (string (meth)))
917 return a
1018 end
1119 end
1220end
1321
14- function IMPLEMENT_TWO_ARG_FUNC (meth, symnm; lib= :basic_ )
22+ function IMPLEMENT_TWO_ARG_FUNC (modu, meth, symnm; lib= :basic_ )
23+ methbang = Symbol (meth, " !" )
24+ if isa (modu, Symbol)
25+ meth = :($ modu.$ meth)
26+ end
27+
1528 @eval begin
1629 function ($ meth)(b1:: SymbolicType , b2:: Number )
1730 a = Basic ()
31+ ($ methbang)(a,b1,b2)
32+ a
33+ end
34+ function ($ methbang)(a:: Basic , b1:: SymbolicType , b2:: Number )
1835 b1, b2 = promote (b1, b2)
1936 err_code = ccall (($ (string (lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
2037 throw_if_error (err_code, $ (string (meth)))
2138 return a
2239 end
40+
2341 end
2442end
2543
@@ -59,7 +77,7 @@ for (meth, libnm, modu) in [
5977 (:floor , :floor , :Base )
6078 ]
6179 eval (:(import $ modu.$ meth))
62- IMPLEMENT_ONE_ARG_FUNC (:( $ modu. $ meth) , libnm)
80+ IMPLEMENT_ONE_ARG_FUNC (modu, meth, libnm)
6381end
6482
6583for (meth, libnm, modu) in [
@@ -71,7 +89,7 @@ for (meth, libnm, modu) in [
7189 (:erfc , :erfc , :SpecialFunctions )
7290]
7391 eval (:(import $ modu.$ meth))
74- IMPLEMENT_ONE_ARG_FUNC (:( $ modu. $ meth) , libnm)
92+ IMPLEMENT_ONE_ARG_FUNC (modu, meth, libnm)
7593end
7694
7795for (meth, libnm, modu) in [
@@ -80,23 +98,31 @@ for (meth, libnm, modu) in [
8098 (:loggamma ,:loggamma ,:SpecialFunctions ),
8199 ]
82100 eval (:(import $ modu.$ meth))
83- IMPLEMENT_TWO_ARG_FUNC (:( $ modu. $ meth) , libnm)
101+ IMPLEMENT_TWO_ARG_FUNC (modu, meth, libnm)
84102end
85103
86- Base. abs2 (x:: SymEngine.Basic ) = abs (x)^ 2
87-
104+ function abs2! (a:: Basic , x:: Basic )
105+ abs! (a, x)
106+ mul! (a, a, a)
107+ a
108+ end
109+ function Base. abs2 (x:: Basic )
110+ a = Basic ()
111+ abs2! (a, x)
112+ a
113+ end
88114
89115
90116if get_symbol (:basic_atan2 ) != C_NULL
91117 import Base. atan
92- IMPLEMENT_TWO_ARG_FUNC (:( Base. atan) , :atan2 )
118+ IMPLEMENT_TWO_ARG_FUNC (:Base , : atan , :atan2 )
93119end
94120
95121# export not import
96122for (meth, libnm) in [
97123 (:lambertw ,:lambertw ), # in add-on packages, not base
98124 ]
99- IMPLEMENT_ONE_ARG_FUNC (meth, libnm)
125+ IMPLEMENT_ONE_ARG_FUNC (nothing , meth, libnm)
100126 eval (Expr (:export , meth))
101127end
102128
@@ -118,7 +144,7 @@ for (meth, libnm) in [(:gcd, :gcd),
118144 (:mod , :mod_f ),
119145 ]
120146 eval (:(import Base.$ meth))
121- IMPLEMENT_TWO_ARG_FUNC (:( Base. $ meth) , libnm, lib= :ntheory_ )
147+ IMPLEMENT_TWO_ARG_FUNC (:Base , meth, libnm; lib= :ntheory_ )
122148end
123149
124150Base. binomial (n:: Basic , k:: Number ) = binomial (N (n), N (k)) # ntheory_binomial seems wrong
@@ -129,18 +155,18 @@ Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k))
129155# # but not (:fibonacci,:fibonacci), (:lucas, :lucas) (Basic type is not the signature)
130156for (meth, libnm) in [(:nextprime ,:nextprime )
131157 ]
132- IMPLEMENT_ONE_ARG_FUNC (meth, libnm, lib= :ntheory_ )
158+ IMPLEMENT_ONE_ARG_FUNC (nothing , meth, libnm, lib= :ntheory_ )
133159 eval (Expr (:export , meth))
134160end
135161
136162" Return coefficient of `x^n` term, `x` a symbol"
137- function coeff (b:: Basic , x, n)
138- c = Basic ()
163+ function coeff! (a:: Basic , b:: Basic , x, n)
139164 out = ccall ((:basic_coeff , libsymengine), Nothing,
140165 (Ref{Basic},Ref{Basic},Ref{Basic},Ref{Basic}),
141- c ,b,Basic (x), Basic (n))
142- c
166+ a ,b,Basic (x), Basic (n))
167+ a
143168end
169+ coeff (b:: Basic , x, n) = coeff! (Basic (), b, x, n)
144170
145171function Base. convert (:: Type{CVecBasic} , x:: Vector{T} ) where T
146172 vec = CVecBasic ()
0 commit comments