Skip to content

Commit 4ae979c

Browse files
authored
Merge pull request #280 from jverzani/sign_has_symbol
add sign, has_symbol
2 parents 58f6665 + b223c02 commit 4ae979c

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

src/mathfuns.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ for (meth, libnm, modu) in [
5555
(:log,:log,:Base),
5656
(:sqrt,:sqrt,:Base),
5757
(:exp,:exp,:Base),
58+
(:sign, :sign, :Base),
5859
(:eta,:dirichlet_eta,:SpecialFunctions),
5960
(:zeta,:zeta,:SpecialFunctions),
6061
]

src/subs.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ subs(ex::T, d::Pair...) where {T <: SymbolicType} = subs(ex, [(p.first, p.second
3838
## Allow an expression to be called, as with ex(2). When there is more than one symbol, one can rely on order of `free_symbols` or
3939
## be explicit by passing in pairs : `ex(x=>1, y=>2)` or a dict `ex(Dict(x=>1, y=>2))`.
4040
function (ex::Basic)(args...)
41-
xs = free_symbols(ex)
42-
subs(ex, collect(zip(xs, args))...)
41+
xs = free_symbols(ex)
42+
isempty(xs) && return ex
43+
subs(ex, collect(zip(xs, args))...)
4344
end
4445
(ex::Basic)(x::AbstractDict) = subs(ex, x)
45-
(ex::Basic)(x::Pair...) = subs(ex, x...)
46+
(ex::Basic)(x::Pair, xs::Pair...) = subs(ex, x, xs...)
4647

4748

4849
## Lambdify

src/types.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,18 @@ BasicTrigFunction = Union{[SymEngine.BasicType{Val{i}} for i in trig_types]...}
232232

233233

234234
###
235+
"Is expression a symbol"
236+
is_symbol(x::Basic) = is_symbol(BasicType(x))
237+
is_symbol(x::BasicType{Val{:Symbol}}) = true
238+
is_symbol(x::BasicType) = false
239+
240+
241+
"Does expression contain the symbol"
242+
function has_symbol(ex::Basic, x::Basic)
243+
is_symbol(x) || throw(ArgumentError("Not a symbol"))
244+
res = ccall((:basic_has_symbol, libsymengine), Cuint, (Ref{Basic},Ref{Basic}), ex, x)
245+
Bool(convert(Int, res))
246+
end
235247

236248

237249
" Return free symbols in an expression as a `Set`"
@@ -305,4 +317,3 @@ function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{B
305317
throw_if_error(res)
306318
return a
307319
end
308-

test/runtests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,24 @@ A = [x 2; x 1]
185185
x,y,z = symbols("x y z")
186186
@test length(SymEngine.free_symbols([x*y, y,z])) == 3
187187

188+
# is/has/free symbol(s)
189+
@vars x y z
190+
@test SymEngine.is_symbol(x)
191+
@test !SymEngine.is_symbol(x(2))
192+
@test !SymEngine.is_symbol(x^2)
193+
@test SymEngine.has_symbol(x^2, x)
194+
@test SymEngine.has_symbol(sin(sin(sin(x))), x)
195+
@test !SymEngine.has_symbol(x^2, y)
196+
@test Set(free_symbols(x*y)) == Set([x,y])
197+
@test Set(free_symbols(x*y^z)) != Set([x,y])
198+
199+
# call without specifying variables
200+
@vars x y
201+
z = x(2)
202+
@test x(2) == 2
203+
@test (x*y^2)(1,2) == subs(x*y^2, x=>1, y=>2) == (x*y^2)(x=>1, y=>2)
204+
@test z() == 2
205+
@test z(1) == 2
188206

189207
## check that callable symengine expressions can be used as functions for duck-typed functions
190208
@vars x

0 commit comments

Comments
 (0)