Skip to content

Commit b177158

Browse files
committed
cleanup
1 parent 4a07f2c commit b177158

File tree

1 file changed

+64
-9
lines changed

1 file changed

+64
-9
lines changed

ext/SymEngineSymbolicUtilsExt.jl

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,34 @@ module SymEngineSymbolicUtilsExt
33
using SymEngine
44
using SymbolicUtils
55
import SymEngine: SymbolicType
6+
7+
#
8+
function is_number(a::SymEngine.Basic)
9+
cls = SymEngine.get_symengine_class(a)
10+
any(==(cls), SymEngine.number_types) && return true
11+
false
12+
end
13+
14+
15+
λ(x::SymEngine.SymbolicType) = λ(Val(SymEngine.get_symengine_class(x)))
16+
λ(::Val{T}) where {T} = getfield(Main, Symbol(lowercase(string(T))))
17+
18+
λ(::Val{:Add}) = +; λ(::Val{:Sub}) = -
19+
λ(::Val{:Mul}) = *; λ(::Val{:Div}) = /
20+
λ(::Val{:Pow}) = ^
21+
λ(::Val{:re}) = real; λ(::Val{:im}) = imag
22+
λ(::Val{:Abs}) = abs
23+
λ(::Val{:Log}) = log
24+
λ(::Val{:Sin}) = sin; λ(::Val{:Cos}) = cos; λ(::Val{:Tan}) = tan
25+
λ(::Val{:Csc}) = csc; λ(::Val{:Sec}) = sec; λ(::Val{:Cot}) = cot
26+
λ(::Val{:Asin}) = asin; λ(::Val{:Acos}) = acos; λ(::Val{:Atan}) = atan
27+
λ(::Val{:Acsc}) = acsc; λ(::Val{:Asec}) = asec; λ(::Val{:Acot}) = acot
28+
λ(::Val{:Sinh}) = sinh; λ(::Val{:Cosh}) = cosh; λ(::Val{:Tanh}) = tanh
29+
λ(::Val{:Csch}) = csch; λ(::Val{:Sech}) = sech; λ(::Val{:Coth}) = coth
30+
λ(::Val{:Asinh}) = asinh; λ(::Val{:Acosh}) = acosh; λ(::Val{:Atanh}) = atanh
31+
λ(::Val{:Acsch}) = acsch; λ(::Val{:Asech}) = asech; λ(::Val{:Acoth}) = acoth
32+
λ(::Val{:Gamma}) = gamma; λ(::Val{:Zeta}) = zeta; λ(::Val{:LambertW}) = lambertw
33+
634
#==
735
Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression.
836
==#
@@ -13,20 +41,17 @@ function SymbolicUtils.istree(x::SymEngine.SymbolicType)
1341
return true
1442
end
1543

16-
#==
17-
f x is a term as defined by istree(x), exprhead(x) must return a symbol, corresponding to the head of the Expr most similar to the term x. If x represents a function call, for example, the exprhead is :call. If x represents an indexing operation, such as arr[i], then exprhead is :ref. Note that exprhead is different from operation and both functions should be defined correctly in order to let other packages provide code generation and pattern matching features.
18-
function SymbolicUtils.exprhead(x::SymEngine.SymbolicType) # deprecated
19-
:call # this is not right
20-
end
21-
==#
44+
SymbolicUtils.issym(x::SymEngine.SymbolicType) = SymEngine.get_symengine_class(x) == :Symbol
45+
Base.nameof(x::SymEngine.SymbolicType) = Symbol(x)
46+
47+
# no metadata(x), metadata(x, data)
2248

2349
#==
2450
Returns the head (a function object) performed by an expression tree. Called only if istree(::T) is true. Part of the API required for simplify to work. Other required methods are arguments and istree
2551
==#
2652
function SymbolicUtils.operation(x::SymEngine.SymbolicType)
27-
@assert istree(x)
28-
nm = SymEngine.map_fn(SymEngine.get_symengine_class(x), SymEngine.fn_map)
29-
return getfield(Main, nm)
53+
istree(x) || error("$(typeof(x)) doesn't have an operation!")
54+
return λ(x)
3055
end
3156

3257

@@ -45,5 +70,35 @@ function SymbolicUtils.similarterm(t::SymEngine.SymbolicType, f, args, symtype=n
4570
f(args...) # default
4671
end
4772

73+
# Needed for some simplification routines
74+
# a total order <ₑ
75+
import SymbolicUtils: <ₑ, isterm, isadd, ismul, issym, cmp_mul_adds, cmp_term_term
76+
function SymbolicUtils.:<ₑ(a::SymEngine.Basic, b::SymEngine.Basic)
77+
if isterm(a) && !isterm(b)
78+
return false
79+
elseif isterm(b) && !isterm(a)
80+
return true
81+
elseif (isadd(a) || ismul(a)) && (isadd(b) || ismul(b))
82+
return cmp_mul_adds(a, b)
83+
elseif issym(a) && issym(b)
84+
nameof(a) < nameof(b)
85+
elseif !istree(a) && !istree(b)
86+
T = typeof(a)
87+
S = typeof(b)
88+
if T == S
89+
is_number(a) && is_number(b) && return N(a) < N(b)
90+
return hash(a) < hash(b)
91+
else
92+
return name(T) < nameof(S)
93+
end
94+
#return T===S ? (T <: Number ? isless(a, b) : hash(a) < hash(b)) : nameof(T) < nameof(S)
95+
elseif istree(b) && !istree(a)
96+
return true
97+
elseif istree(a) && istree(b)
98+
return cmp_term_term(a,b)
99+
else
100+
return !(b <ₑ a)
101+
end
102+
end
48103

49104
end

0 commit comments

Comments
 (0)