diff --git a/ext/SymEngineTermInterfaceExt.jl b/ext/SymEngineTermInterfaceExt.jl index 942a08e..b4c2749 100644 --- a/ext/SymEngineTermInterfaceExt.jl +++ b/ext/SymEngineTermInterfaceExt.jl @@ -1,7 +1,6 @@ module SymEngineTermInterfaceExt import SymEngine -import SymEngine: SymbolicType import TermInterface @@ -22,7 +21,27 @@ import TermInterface λ(::Val{:Csch}) = csch; λ(::Val{:Sech}) = sech; λ(::Val{:Coth}) = coth λ(::Val{:Asinh}) = asinh; λ(::Val{:Acosh}) = acosh; λ(::Val{:Atanh}) = atanh λ(::Val{:Acsch}) = acsch; λ(::Val{:Asech}) = asech; λ(::Val{:Acoth}) = acoth -λ(::Val{:Gamma}) = gamma; λ(::Val{:Zeta}) = zeta; λ(::Val{:LambertW}) = lambertw +λ(::Val{:ATan2}) = atan; +λ(::Val{:Beta}) = SymEngine.SpecialFunctions.beta; +λ(::Val{:Gamma}) = SymEngine.SpecialFunctions.gamma; +λ(::Val{:PolyGamma}) = SymEngine.SpecialFunctions.polygamma; +λ(::Val{:LogGamma}) = SymEngine.SpecialFunctions.loggamma; +λ(::Val{:Erf}) = SymEngine.SpecialFunctions.erf; +λ(::Val{:Erfc}) = SymEngine.SpecialFunctions.erfc; +λ(::Val{:Zeta}) = SymEngine.SpecialFunctions.zeta; +λ(::Val{:LambertW}) = SymEngine.SpecialFunctions.lambertw + + + +const julia_operations = Vector{Any}(missing, length(SymEngine.symengine_classes)) +for (i,s) ∈ enumerate(SymEngine.symengine_classes) + val = try + λ(Val(s)) + catch err + missing + end + julia_operations[i] = val +end #== Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression. @@ -40,7 +59,7 @@ TermInterface.isexpr(x::SymEngine.SymbolicType) = TermInterface.iscall(x) function TermInterface.operation(x::SymEngine.SymbolicType) TermInterface.iscall(x) || error("$(typeof(x)) doesn't have an operation!") - return λ(x) + return julia_operations[SymEngine.get_type(x) + 1] end function TermInterface.arguments(x::SymEngine.SymbolicType) diff --git a/src/mathfuns.jl b/src/mathfuns.jl index 8139129..81d074b 100644 --- a/src/mathfuns.jl +++ b/src/mathfuns.jl @@ -51,19 +51,42 @@ for (meth, libnm, modu) in [ (:acsch,:acsch,:Base), (:atanh,:atanh,:Base), (:acoth,:acoth,:Base), - (:gamma,:gamma,:SpecialFunctions), (:log,:log,:Base), (:sqrt,:sqrt,:Base), (:exp,:exp,:Base), (:sign, :sign, :Base), - (:eta,:dirichlet_eta,:SpecialFunctions), - (:zeta,:zeta,:SpecialFunctions), + (:ceil, :ceiling, :Base), + (:floor, :floor, :Base) ] eval(:(import $modu.$meth)) IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm) end + +for (meth, libnm, modu) in [ + (:gamma,:gamma,:SpecialFunctions), + (:loggamma,:loggamma,:SpecialFunctions), + (:eta,:dirichlet_eta,:SpecialFunctions), + (:zeta,:zeta,:SpecialFunctions), + (:erf, :erf, :SpecialFunctions), + (:erfc, :erfc, :SpecialFunctions) +] + eval(:(import $modu.$meth)) + IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm) +end + +for (meth, libnm, modu) in [ + (:beta, :beta, :SpecialFunctions), + (:polygamma, :polygamma, :SpecialFunctions), + (:loggamma,:loggamma,:SpecialFunctions), + ] + eval(:(import $modu.$meth)) + IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm) +end + Base.abs2(x::SymEngine.Basic) = abs(x)^2 + + if get_symbol(:basic_atan2) != C_NULL import Base.atan IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2) diff --git a/src/numerics.jl b/src/numerics.jl index c0b80d1..017ade1 100644 --- a/src/numerics.jl +++ b/src/numerics.jl @@ -262,12 +262,6 @@ end trunc(x::Basic, args...) = Basic(trunc(N(x), args...)) trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,args...)) -ceil(x::Basic) = Basic(ceil(N(x))) -ceil(::Type{T},x::Basic) where {T <: Integer} = convert(T, ceil(x)) - -floor(x::Basic) = Basic(floor(N(x))) -floor(::Type{T},x::Basic) where {T <: Integer} = convert(T, floor(x)) - round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...)) round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...)) diff --git a/src/subs.jl b/src/subs.jl index 5f5b9e4..9e23f8a 100644 --- a/src/subs.jl +++ b/src/subs.jl @@ -62,18 +62,22 @@ fn_map = Dict( map_fn(key, fn_map) = haskey(fn_map, key) ? fn_map[key] : Symbol(lowercase(string(key))) +const julia_classes = map_fn.(symengine_classes, (fn_map,)) +get_julia_class(x::Basic) = julia_classes[get_type(x) + 1] +Base.nameof(ex::Basic) = Symbol(toString(ex)) + function _convert(::Type{Expr}, ex::Basic) fn = get_symengine_class(ex) if fn == :Symbol - return Symbol(toString(ex)) + return nameof(ex) elseif (fn in number_types) || (fn == :Constant) return N(ex) end as = get_args(ex) - - Expr(:call, map_fn(fn, fn_map), [_convert(Expr,a) for a in as]...) + fn′ = get_julia_class(ex) + Expr(:call, fn′, [_convert(Expr,a) for a in as]...) end @@ -81,7 +85,7 @@ function convert(::Type{Expr}, ex::Basic) fn = get_symengine_class(ex) if fn == :Symbol - return Expr(:call, :*, Symbol(toString(ex)), 1) + return Expr(:call, :*, nameof(ex), 1) elseif (fn in number_types) || (fn == :Constant) return Expr(:call, :*, N(ex), 1) end