Skip to content

Commit d4f4c53

Browse files
authored
Merge pull request #287 from jverzani/julia_operation
Julia operation
2 parents d0eea29 + 1aa8317 commit d4f4c53

File tree

4 files changed

+56
-16
lines changed

4 files changed

+56
-16
lines changed

ext/SymEngineTermInterfaceExt.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module SymEngineTermInterfaceExt
22

33
import SymEngine
4-
import SymEngine: SymbolicType
54
import TermInterface
65

76

@@ -22,7 +21,27 @@ import TermInterface
2221
λ(::Val{:Csch}) = csch; λ(::Val{:Sech}) = sech; λ(::Val{:Coth}) = coth
2322
λ(::Val{:Asinh}) = asinh; λ(::Val{:Acosh}) = acosh; λ(::Val{:Atanh}) = atanh
2423
λ(::Val{:Acsch}) = acsch; λ(::Val{:Asech}) = asech; λ(::Val{:Acoth}) = acoth
25-
λ(::Val{:Gamma}) = gamma; λ(::Val{:Zeta}) = zeta; λ(::Val{:LambertW}) = lambertw
24+
λ(::Val{:ATan2}) = atan;
25+
λ(::Val{:Beta}) = SymEngine.SpecialFunctions.beta;
26+
λ(::Val{:Gamma}) = SymEngine.SpecialFunctions.gamma;
27+
λ(::Val{:PolyGamma}) = SymEngine.SpecialFunctions.polygamma;
28+
λ(::Val{:LogGamma}) = SymEngine.SpecialFunctions.loggamma;
29+
λ(::Val{:Erf}) = SymEngine.SpecialFunctions.erf;
30+
λ(::Val{:Erfc}) = SymEngine.SpecialFunctions.erfc;
31+
λ(::Val{:Zeta}) = SymEngine.SpecialFunctions.zeta;
32+
λ(::Val{:LambertW}) = SymEngine.SpecialFunctions.lambertw
33+
34+
35+
36+
const julia_operations = Vector{Any}(missing, length(SymEngine.symengine_classes))
37+
for (i,s) enumerate(SymEngine.symengine_classes)
38+
val = try
39+
λ(Val(s))
40+
catch err
41+
missing
42+
end
43+
julia_operations[i] = val
44+
end
2645

2746
#==
2847
Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression.
@@ -40,7 +59,7 @@ TermInterface.isexpr(x::SymEngine.SymbolicType) = TermInterface.iscall(x)
4059

4160
function TermInterface.operation(x::SymEngine.SymbolicType)
4261
TermInterface.iscall(x) || error("$(typeof(x)) doesn't have an operation!")
43-
return λ(x)
62+
return julia_operations[SymEngine.get_type(x) + 1]
4463
end
4564

4665
function TermInterface.arguments(x::SymEngine.SymbolicType)

src/mathfuns.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,42 @@ for (meth, libnm, modu) in [
5151
(:acsch,:acsch,:Base),
5252
(:atanh,:atanh,:Base),
5353
(:acoth,:acoth,:Base),
54-
(:gamma,:gamma,:SpecialFunctions),
5554
(:log,:log,:Base),
5655
(:sqrt,:sqrt,:Base),
5756
(:exp,:exp,:Base),
5857
(:sign, :sign, :Base),
59-
(:eta,:dirichlet_eta,:SpecialFunctions),
60-
(:zeta,:zeta,:SpecialFunctions),
58+
(:ceil, :ceiling, :Base),
59+
(:floor, :floor, :Base)
6160
]
6261
eval(:(import $modu.$meth))
6362
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
6463
end
64+
65+
for (meth, libnm, modu) in [
66+
(:gamma,:gamma,:SpecialFunctions),
67+
(:loggamma,:loggamma,:SpecialFunctions),
68+
(:eta,:dirichlet_eta,:SpecialFunctions),
69+
(:zeta,:zeta,:SpecialFunctions),
70+
(:erf, :erf, :SpecialFunctions),
71+
(:erfc, :erfc, :SpecialFunctions)
72+
]
73+
eval(:(import $modu.$meth))
74+
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
75+
end
76+
77+
for (meth, libnm, modu) in [
78+
(:beta, :beta, :SpecialFunctions),
79+
(:polygamma, :polygamma, :SpecialFunctions),
80+
(:loggamma,:loggamma,:SpecialFunctions),
81+
]
82+
eval(:(import $modu.$meth))
83+
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
84+
end
85+
6586
Base.abs2(x::SymEngine.Basic) = abs(x)^2
6687

88+
89+
6790
if get_symbol(:basic_atan2) != C_NULL
6891
import Base.atan
6992
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)

src/numerics.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,6 @@ end
262262
trunc(x::Basic, args...) = Basic(trunc(N(x), args...))
263263
trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,args...))
264264

265-
ceil(x::Basic) = Basic(ceil(N(x)))
266-
ceil(::Type{T},x::Basic) where {T <: Integer} = convert(T, ceil(x))
267-
268-
floor(x::Basic) = Basic(floor(N(x)))
269-
floor(::Type{T},x::Basic) where {T <: Integer} = convert(T, floor(x))
270-
271265
round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...))
272266
round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...))
273267

src/subs.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,26 +62,30 @@ fn_map = Dict(
6262

6363
map_fn(key, fn_map) = haskey(fn_map, key) ? fn_map[key] : Symbol(lowercase(string(key)))
6464

65+
const julia_classes = map_fn.(symengine_classes, (fn_map,))
66+
get_julia_class(x::Basic) = julia_classes[get_type(x) + 1]
67+
Base.nameof(ex::Basic) = Symbol(toString(ex))
68+
6569
function _convert(::Type{Expr}, ex::Basic)
6670
fn = get_symengine_class(ex)
6771

6872
if fn == :Symbol
69-
return Symbol(toString(ex))
73+
return nameof(ex)
7074
elseif (fn in number_types) || (fn == :Constant)
7175
return N(ex)
7276
end
7377

7478
as = get_args(ex)
75-
76-
Expr(:call, map_fn(fn, fn_map), [_convert(Expr,a) for a in as]...)
79+
fn′ = get_julia_class(ex)
80+
Expr(:call, fn′, [_convert(Expr,a) for a in as]...)
7781
end
7882

7983

8084
function convert(::Type{Expr}, ex::Basic)
8185
fn = get_symengine_class(ex)
8286

8387
if fn == :Symbol
84-
return Expr(:call, :*, Symbol(toString(ex)), 1)
88+
return Expr(:call, :*, nameof(ex), 1)
8589
elseif (fn in number_types) || (fn == :Constant)
8690
return Expr(:call, :*, N(ex), 1)
8791
end

0 commit comments

Comments
 (0)