Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/display.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
function toString(b::SymbolicType)
b = Basic(b)
if b.ptr == C_NULL
error("Trying to print an uninitialized SymEngine Basic variable.")
return "<Unitialized Basic value>"
end
a = ccall((:basic_str_julia, libsymengine), Cstring, (Ref{Basic}, ), b)
string = unsafe_string(a)
Expand Down
56 changes: 41 additions & 15 deletions src/mathfuns.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,43 @@
using SpecialFunctions

function IMPLEMENT_ONE_ARG_FUNC(meth, symnm; lib=:basic_)
function IMPLEMENT_ONE_ARG_FUNC(modu, meth, symnm; lib=:basic_)
methbang = Symbol(meth, "!")
if isa(modu, Symbol)
meth = :($modu.$meth)
end
@eval begin
function ($meth)(b::SymbolicType)
a = Basic()
($methbang)(a,b)
a
end
function ($methbang)(a::Basic, b::SymbolicType)
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}), a, b)
throw_if_error(err_code, $(string(meth)))
return a
end
end
end

function IMPLEMENT_TWO_ARG_FUNC(meth, symnm; lib=:basic_)
function IMPLEMENT_TWO_ARG_FUNC(modu, meth, symnm; lib=:basic_)
methbang = Symbol(meth, "!")
if isa(modu, Symbol)
meth = :($modu.$meth)
end

@eval begin
function ($meth)(b1::SymbolicType, b2::Number)
a = Basic()
($methbang)(a,b1,b2)
a
end
function ($methbang)(a::Basic, b1::SymbolicType, b2::Number)
b1, b2 = promote(b1, b2)
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
throw_if_error(err_code, $(string(meth)))
return a
end

end
end

Expand Down Expand Up @@ -59,7 +77,7 @@ for (meth, libnm, modu) in [
(:floor, :floor, :Base)
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
end

for (meth, libnm, modu) in [
Expand All @@ -71,7 +89,7 @@ for (meth, libnm, modu) in [
(:erfc, :erfc, :SpecialFunctions)
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
end

for (meth, libnm, modu) in [
Expand All @@ -80,23 +98,31 @@ for (meth, libnm, modu) in [
(:loggamma,:loggamma,:SpecialFunctions),
]
eval(:(import $modu.$meth))
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
IMPLEMENT_TWO_ARG_FUNC(modu, meth, libnm)
end

Base.abs2(x::SymEngine.Basic) = abs(x)^2

function abs2!(a::Basic, x::Basic)
abs!(a, x)
mul!(a, a, a)
a
end
function Base.abs2(x::Basic)
a = Basic()
abs2!(a, x)
a
end


if get_symbol(:basic_atan2) != C_NULL
import Base.atan
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)
IMPLEMENT_TWO_ARG_FUNC(:Base, :atan, :atan2)
end

# export not import
for (meth, libnm) in [
(:lambertw,:lambertw), # in add-on packages, not base
]
IMPLEMENT_ONE_ARG_FUNC(meth, libnm)
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm)
eval(Expr(:export, meth))
end

Expand All @@ -118,7 +144,7 @@ for (meth, libnm) in [(:gcd, :gcd),
(:mod, :mod_f),
]
eval(:(import Base.$meth))
IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_)
IMPLEMENT_TWO_ARG_FUNC(:Base, meth, libnm; lib=:ntheory_)
end

Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong
Expand All @@ -129,18 +155,18 @@ Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k))
## but not (:fibonacci,:fibonacci), (:lucas, :lucas) (Basic type is not the signature)
for (meth, libnm) in [(:nextprime,:nextprime)
]
IMPLEMENT_ONE_ARG_FUNC(meth, libnm, lib=:ntheory_)
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm, lib=:ntheory_)
eval(Expr(:export, meth))
end

"Return coefficient of `x^n` term, `x` a symbol"
function coeff(b::Basic, x, n)
c = Basic()
function coeff!(a::Basic, b::Basic, x, n)
out = ccall((:basic_coeff, libsymengine), Nothing,
(Ref{Basic},Ref{Basic},Ref{Basic},Ref{Basic}),
c,b,Basic(x), Basic(n))
c
a,b,Basic(x), Basic(n))
a
end
coeff(b::Basic, x, n) = coeff!(Basic(), b, x, n)

function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T
vec = CVecBasic()
Expand Down
26 changes: 19 additions & 7 deletions src/mathops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ end


## main ops
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (://, :div), (:^, :pow))
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (:^, :pow))
tup = (Base.Symbol("basic_$libnm"), libsymengine)
opbang = Symbol(libnm,:!)
@eval begin
function ($op)(b1::Basic, b2::Basic)
a = Basic()
function ($opbang)(a::Basic, b1::Basic, b2::Basic)
err_code = ccall($tup, Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
throw_if_error(err_code, $(string(libnm)))
return a
end
function ($op)(b1::Basic, b2::Basic)
a = Basic()
($opbang)(a, b1, b2)
return a
end
($op)(b1::BasicType, b2::BasicType) = ($op)(Basic(b1), Basic(b2))
end
end
Expand All @@ -30,16 +35,20 @@ end
# In contrast to other standard operations such as `+`, `*`, `-`, and `/`,
# Julia doesn't implement a general fallback of `//` for `Number`s promoting
# the input arguments. Thus, we implement this here explicitly.
Base.:(//)(b1::SymbolicType, b2::SymbolicType) = b1 / b2
Base.:(//)(b1::SymbolicType, b2::Number) = //(promote(b1, b2)...)
Base.:(//)(b1::Number, b2::SymbolicType) = //(promote(b1, b2)...)


function sum(v::CVecBasic)
a = Basic()
function sum!(a::Basic, v::CVecBasic)
err_code = ccall((:basic_add_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
throw_if_error(err_code, "add_vec")
return a
end
function sum(v::CVecBasic)
a = Basic()
sum!(a, v)
end

+(b1::Basic, b2::Basic, b3::Basic, bs...) = sum(convert(CVecBasic, [b1, b2, b3, bs...]))
+(b1::Basic, b2::Basic, b3, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)
Expand All @@ -51,12 +60,15 @@ end
+(b1, b2, b3::Basic, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)


function prod(v::CVecBasic)
a = Basic()
function prod!(a::Basic, v::CVecBasic)
err_code = ccall((:basic_mul_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
throw_if_error(err_code, "mul_vec")
return a
end
function prod(v::CVecBasic)
a = Basic()
prod!(a, v)
end

*(b1::Basic, b2::Basic, b3::Basic, bs::Vararg{Number, N}) where {N} = prod(convert(CVecBasic, [b1, b2, b3, bs...]))
*(b1::Basic, b2::Basic, b3::Number, bs::Vararg{Number, N}) where {N} = *(Basic(b1), Basic(b2), Basic(b3), bs...)
Expand Down
7 changes: 5 additions & 2 deletions src/numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,14 @@ end


## Conversions SymEngine -> Julia
function as_numer_denom(x::Basic)
a, b = Basic(), Basic()
function as_numer_denom!(a::Basic, b::Basic, x::Basic)
ccall((:basic_as_numer_denom, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b, x)
return a, b
end
function as_numer_denom(x::Basic)
a, b = Basic(), Basic()
as_numer_denom!(a,b,x)
end

as_numer_denom(x::BasicType) = as_numer_denom(Basic(x))
denominator(x::SymbolicType) = as_numer_denom(x)[2]
Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
IMPLEMENT_ONE_ARG_FUNC(:expand, :expand)
IMPLEMENT_ONE_ARG_FUNC(nothing, :expand, :expand)

if get_symbol(:basic_cse) != C_NULL
function cse(exprs...)
Expand Down
9 changes: 9 additions & 0 deletions src/subs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@ subs(ex, x=>1, y=>1) # ditto
"""
function subs(ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
s = Basic()
subs!(s, ex, var, val)
end

function subs!(s::Basic, ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
err_code = ccall((:basic_subs2, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), s, ex, var, val)
throw_if_error(err_code, ex)
return s
end

function subs(ex::T, d::CMapBasicBasic) where T<:SymbolicType
s = Basic()
subs!(s, ex, d)
end
function subs!(s::Basic, ex::T, d::CMapBasicBasic) where T<:SymbolicType
err_code = ccall((:basic_subs, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ptr{Cvoid}), s, ex, d.ptr)
throw_if_error(err_code, ex)
return s
end


subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d))
subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2])
subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...)
Expand Down
29 changes: 19 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,32 @@ end
" Return free symbols in an expression as a `Set`"
function free_symbols(ex::Basic)
syms = CSetBasic()
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
free_symbols!(syms, ex)
convert(Vector, syms)
end
function free_symbols!(syms, ex::Basic)
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
syms
end

free_symbols(ex::BasicType) = free_symbols(Basic(ex))

_flat(A) = mapreduce(x->isa(x,Array) ? _flat(x) : x, vcat, A, init=Basic[]) # from rosetta code example

free_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([free_symbols(ex) for ex in exs]))
free_symbols(exs::Tuple) = unique(_flat([free_symbols(ex) for ex in exs]))

"Return function symbols in an expression as a `Set`"
function function_symbols(ex::Basic)
syms = CSetBasic()
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
function_symbols!(syms, ex)
convert(Vector, syms)
end
function function_symbols!(syms, ex::Basic)
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
syms
end

function_symbols(ex::BasicType) = function_symbols(Basic(ex))
function_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([function_symbols(ex) for ex in exs]))
function_symbols(exs::Tuple) = unique(_flat([function_symbols(ex) for ex in exs]))
Expand All @@ -309,23 +321,20 @@ end
"Return arguments of a function call as a vector of `Basic` objects"
function get_args(ex::Basic)
args = CVecBasic()
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
get_args!(args, ex)
convert(Vector, args)
end

function get_args!(args, ex::Basic)
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
end

## so that Dicts will work
basic_hash(ex::Basic) = ccall((:basic_hash, libsymengine), UInt, (Ref{Basic}, ), ex)
# similar definition as in Base for general objects
Base.hash(ex::Basic, h::UInt) = Base.hash_uint(3h - basic_hash(ex))
Base.hash(ex::BasicType, h::UInt) = hash(Basic(ex), h)

function coeff(b::Basic, x::Basic, n::Basic)
c = Basic()
ccall((:basic_coeff, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), c, b, x, n)
return c
end

coeff(b::Basic, x::Basic) = coeff(b, x, one(Basic))

function Serialization.serialize(s::Serialization.AbstractSerializer, m::Basic)
Serialization.serialize_type(s, typeof(m))
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ let
@vars w
end
@test_throws UndefVarError isdefined(w)
@test_throws Exception show(Basic())
@test repr(Basic()) == "<Unitialized Basic value>"

# test @vars constructions
@vars a, b[0:4], c(), d=>"D"
Expand Down Expand Up @@ -354,3 +354,5 @@ end
close(iobuf)
@test deserialized == data
end

VERSION >= v"1.9.0" && include("test-allocations.jl")
18 changes: 18 additions & 0 deletions test/test-allocations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# test for allocations
@vars x y
a = Basic()
@testset "non-allocating(ish) methods" begin
sin(x), cos(x), abs(x)
x^x, x + x, x*x, x-x, x/x # warm up

@test (@allocations SymEngine.sin!(a,x)) == 0
@test (@allocations SymEngine.cos!(a,x)) == 0
@test (@allocations SymEngine.pow!(a,x,x)) == 0

# still allocates 1 (or 2)
@test (@allocations SymEngine.add!(a,x,y)) < (@allocations x+y)
@test (@allocations SymEngine.sub!(a,x,y)) < (@allocations x-y)
@test (@allocations SymEngine.mul!(a,x,y)) < (@allocations x*y)
@test (@allocations SymEngine.div!(a,x,y)) < (@allocations x/y)
@test (@allocations SymEngine.abs2!(a,x)) < (@allocations abs2(x))
end
Loading