Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions ext/SymEngineTermInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module SymEngineTermInterfaceExt

import SymEngine
import SymEngine: SymbolicType
import TermInterface


Expand All @@ -22,7 +21,27 @@ import TermInterface
λ(::Val{:Csch}) = csch; λ(::Val{:Sech}) = sech; λ(::Val{:Coth}) = coth
λ(::Val{:Asinh}) = asinh; λ(::Val{:Acosh}) = acosh; λ(::Val{:Atanh}) = atanh
λ(::Val{:Acsch}) = acsch; λ(::Val{:Asech}) = asech; λ(::Val{:Acoth}) = acoth
λ(::Val{:Gamma}) = gamma; λ(::Val{:Zeta}) = zeta; λ(::Val{:LambertW}) = lambertw
λ(::Val{:ATan2}) = atan;
λ(::Val{:Beta}) = SymEngine.SpecialFunctions.beta;
λ(::Val{:Gamma}) = SymEngine.SpecialFunctions.gamma;
λ(::Val{:PolyGamma}) = SymEngine.SpecialFunctions.polygamma;
λ(::Val{:LogGamma}) = SymEngine.SpecialFunctions.loggamma;
λ(::Val{:Erf}) = SymEngine.SpecialFunctions.erf;
λ(::Val{:Erfc}) = SymEngine.SpecialFunctions.erfc;
λ(::Val{:Zeta}) = SymEngine.SpecialFunctions.zeta;
λ(::Val{:LambertW}) = SymEngine.SpecialFunctions.lambertw



const julia_operations = Vector{Any}(missing, length(SymEngine.symengine_classes))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is const and gets changed afterwards?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Julia, the binding is constant, but not its components. (To be honest, I'm not sure this gives any performance improvement here.)

for (i,s) ∈ enumerate(SymEngine.symengine_classes)
val = try
λ(Val(s))
catch err
missing
end
julia_operations[i] = val
end

#==
Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression.
Expand All @@ -40,7 +59,7 @@ TermInterface.isexpr(x::SymEngine.SymbolicType) = TermInterface.iscall(x)

function TermInterface.operation(x::SymEngine.SymbolicType)
TermInterface.iscall(x) || error("$(typeof(x)) doesn't have an operation!")
return λ(x)
return julia_operations[SymEngine.get_type(x) + 1]
end

function TermInterface.arguments(x::SymEngine.SymbolicType)
Expand Down
29 changes: 26 additions & 3 deletions src/mathfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,42 @@ for (meth, libnm, modu) in [
(:acsch,:acsch,:Base),
(:atanh,:atanh,:Base),
(:acoth,:acoth,:Base),
(:gamma,:gamma,:SpecialFunctions),
(:log,:log,:Base),
(:sqrt,:sqrt,:Base),
(:exp,:exp,:Base),
(:sign, :sign, :Base),
(:eta,:dirichlet_eta,:SpecialFunctions),
(:zeta,:zeta,:SpecialFunctions),
(:ceil, :ceiling, :Base),
(:floor, :floor, :Base)
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
end

for (meth, libnm, modu) in [
(:gamma,:gamma,:SpecialFunctions),
(:loggamma,:loggamma,:SpecialFunctions),
(:eta,:dirichlet_eta,:SpecialFunctions),
(:zeta,:zeta,:SpecialFunctions),
(:erf, :erf, :SpecialFunctions),
(:erfc, :erfc, :SpecialFunctions)
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
end

for (meth, libnm, modu) in [
(:beta, :beta, :SpecialFunctions),
(:polygamma, :polygamma, :SpecialFunctions),
(:loggamma,:loggamma,:SpecialFunctions),
]
eval(:(import $modu.$meth))
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
end

Base.abs2(x::SymEngine.Basic) = abs(x)^2



if get_symbol(:basic_atan2) != C_NULL
import Base.atan
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)
Expand Down
6 changes: 0 additions & 6 deletions src/numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,6 @@ end
trunc(x::Basic, args...) = Basic(trunc(N(x), args...))
trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,args...))

ceil(x::Basic) = Basic(ceil(N(x)))
ceil(::Type{T},x::Basic) where {T <: Integer} = convert(T, ceil(x))

floor(x::Basic) = Basic(floor(N(x)))
floor(::Type{T},x::Basic) where {T <: Integer} = convert(T, floor(x))

round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...))
round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...))

Expand Down
12 changes: 8 additions & 4 deletions src/subs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,30 @@ fn_map = Dict(

map_fn(key, fn_map) = haskey(fn_map, key) ? fn_map[key] : Symbol(lowercase(string(key)))

const julia_classes = map_fn.(symengine_classes, (fn_map,))
get_julia_class(x::Basic) = julia_classes[get_type(x) + 1]
Base.nameof(ex::Basic) = Symbol(toString(ex))

function _convert(::Type{Expr}, ex::Basic)
fn = get_symengine_class(ex)

if fn == :Symbol
return Symbol(toString(ex))
return nameof(ex)
elseif (fn in number_types) || (fn == :Constant)
return N(ex)
end

as = get_args(ex)

Expr(:call, map_fn(fn, fn_map), [_convert(Expr,a) for a in as]...)
fn′ = get_julia_class(ex)
Expr(:call, fn′, [_convert(Expr,a) for a in as]...)
end


function convert(::Type{Expr}, ex::Basic)
fn = get_symengine_class(ex)

if fn == :Symbol
return Expr(:call, :*, Symbol(toString(ex)), 1)
return Expr(:call, :*, nameof(ex), 1)
elseif (fn in number_types) || (fn == :Constant)
return Expr(:call, :*, N(ex), 1)
end
Expand Down
Loading