diff --git a/src/SymEngine.jl b/src/SymEngine.jl index c5d6693..8a1ebaa 100644 --- a/src/SymEngine.jl +++ b/src/SymEngine.jl @@ -21,8 +21,8 @@ const have_mpc = have_component("mpc") const libversion = get_libversion() include("exceptions.jl") -include("types.jl") include("ctypes.jl") +include("types.jl") include("decl.jl") include("display.jl") include("mathops.jl") diff --git a/src/ctypes.jl b/src/ctypes.jl index e043afc..d125d51 100644 --- a/src/ctypes.jl +++ b/src/ctypes.jl @@ -1,4 +1,21 @@ # types from SymEngine to Julia + +## Basic +## Hold a reference to a SymEngine object +mutable struct Basic <: Number + ptr::Ptr{Cvoid} + function Basic() + z = new(C_NULL) + ccall((:basic_new_stack, libsymengine), Nothing, (Ref{Basic}, ), z) + finalizer(basic_free, z) + return z + end + function Basic(v::Ptr{Cvoid}) + z = new(v) + return z + end +end + ## CSetBasic mutable struct CSetBasic ptr::Ptr{Cvoid} @@ -160,7 +177,7 @@ function CDenseMatrix(x::Array{T, 2}) where T end -function Base.convert(::Type{Matrix}, x::CDenseMatrix) +function Base.convert(::Type{Matrix}, x::CDenseMatrix) m,n = Base.size(x) [x[i,j] for i in 1:m, j in 1:n] end diff --git a/src/display.jl b/src/display.jl index 4340c41..061a2fd 100644 --- a/src/display.jl +++ b/src/display.jl @@ -2,7 +2,7 @@ function toString(b::SymbolicType) b = Basic(b) if b.ptr == C_NULL - error("Trying to print an uninitialized SymEngine Basic variable.") + return "" end a = ccall((:basic_str_julia, libsymengine), Cstring, (Ref{Basic}, ), b) string = unsafe_string(a) diff --git a/src/mathfuns.jl b/src/mathfuns.jl index 81d074b..b282480 100644 --- a/src/mathfuns.jl +++ b/src/mathfuns.jl @@ -1,9 +1,17 @@ using SpecialFunctions -function IMPLEMENT_ONE_ARG_FUNC(meth, symnm; lib=:basic_) +function IMPLEMENT_ONE_ARG_FUNC(modu, meth, symnm; lib=:basic_) + methbang = Symbol(meth, "!") + if isa(modu, Symbol) + meth = :($modu.$meth) + end @eval begin function ($meth)(b::SymbolicType) a = Basic() + ($methbang)(a,b) + a + end + function ($methbang)(a::Basic, b::SymbolicType) err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}), a, b) throw_if_error(err_code, $(string(meth))) return a @@ -11,15 +19,25 @@ function IMPLEMENT_ONE_ARG_FUNC(meth, symnm; lib=:basic_) end end -function IMPLEMENT_TWO_ARG_FUNC(meth, symnm; lib=:basic_) +function IMPLEMENT_TWO_ARG_FUNC(modu, meth, symnm; lib=:basic_) + methbang = Symbol(meth, "!") + if isa(modu, Symbol) + meth = :($modu.$meth) + end + @eval begin function ($meth)(b1::SymbolicType, b2::Number) a = Basic() + ($methbang)(a,b1,b2) + a + end + function ($methbang)(a::Basic, b1::SymbolicType, b2::Number) b1, b2 = promote(b1, b2) err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2) throw_if_error(err_code, $(string(meth))) return a end + end end @@ -59,7 +77,7 @@ for (meth, libnm, modu) in [ (:floor, :floor, :Base) ] eval(:(import $modu.$meth)) - IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm) + IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm) end for (meth, libnm, modu) in [ @@ -71,7 +89,7 @@ for (meth, libnm, modu) in [ (:erfc, :erfc, :SpecialFunctions) ] eval(:(import $modu.$meth)) - IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm) + IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm) end for (meth, libnm, modu) in [ @@ -80,23 +98,31 @@ for (meth, libnm, modu) in [ (:loggamma,:loggamma,:SpecialFunctions), ] eval(:(import $modu.$meth)) - IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm) + IMPLEMENT_TWO_ARG_FUNC(modu, meth, libnm) end -Base.abs2(x::SymEngine.Basic) = abs(x)^2 - +function abs2!(a::Basic, x::Basic) + abs!(a, x) + mul!(a, a, a) + a +end +function Base.abs2(x::Basic) + a = Basic() + abs2!(a, x) + a +end if get_symbol(:basic_atan2) != C_NULL import Base.atan - IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2) + IMPLEMENT_TWO_ARG_FUNC(:Base, :atan, :atan2) end # export not import for (meth, libnm) in [ (:lambertw,:lambertw), # in add-on packages, not base ] - IMPLEMENT_ONE_ARG_FUNC(meth, libnm) + IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm) eval(Expr(:export, meth)) end @@ -118,7 +144,7 @@ for (meth, libnm) in [(:gcd, :gcd), (:mod, :mod_f), ] eval(:(import Base.$meth)) - IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_) + IMPLEMENT_TWO_ARG_FUNC(:Base, meth, libnm; lib=:ntheory_) end Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong @@ -129,18 +155,18 @@ Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k)) ## but not (:fibonacci,:fibonacci), (:lucas, :lucas) (Basic type is not the signature) for (meth, libnm) in [(:nextprime,:nextprime) ] - IMPLEMENT_ONE_ARG_FUNC(meth, libnm, lib=:ntheory_) + IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm, lib=:ntheory_) eval(Expr(:export, meth)) end "Return coefficient of `x^n` term, `x` a symbol" -function coeff(b::Basic, x, n) - c = Basic() +function coeff!(a::Basic, b::Basic, x, n) out = ccall((:basic_coeff, libsymengine), Nothing, (Ref{Basic},Ref{Basic},Ref{Basic},Ref{Basic}), - c,b,Basic(x), Basic(n)) - c + a,b,Basic(x), Basic(n)) + a end +coeff(b::Basic, x, n) = coeff!(Basic(), b, x, n) function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T vec = CVecBasic() diff --git a/src/mathops.jl b/src/mathops.jl index da9815c..9a0611a 100644 --- a/src/mathops.jl +++ b/src/mathops.jl @@ -8,15 +8,20 @@ end ## main ops -for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (://, :div), (:^, :pow)) +for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (:^, :pow)) tup = (Base.Symbol("basic_$libnm"), libsymengine) + opbang = Symbol(libnm,:!) @eval begin - function ($op)(b1::Basic, b2::Basic) - a = Basic() + function ($opbang)(a::Basic, b1::Basic, b2::Basic) err_code = ccall($tup, Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2) throw_if_error(err_code, $(string(libnm))) return a end + function ($op)(b1::Basic, b2::Basic) + a = Basic() + ($opbang)(a, b1, b2) + return a + end ($op)(b1::BasicType, b2::BasicType) = ($op)(Basic(b1), Basic(b2)) end end @@ -30,16 +35,20 @@ end # In contrast to other standard operations such as `+`, `*`, `-`, and `/`, # Julia doesn't implement a general fallback of `//` for `Number`s promoting # the input arguments. Thus, we implement this here explicitly. +Base.:(//)(b1::SymbolicType, b2::SymbolicType) = b1 / b2 Base.:(//)(b1::SymbolicType, b2::Number) = //(promote(b1, b2)...) Base.:(//)(b1::Number, b2::SymbolicType) = //(promote(b1, b2)...) -function sum(v::CVecBasic) - a = Basic() +function sum!(a::Basic, v::CVecBasic) err_code = ccall((:basic_add_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr) throw_if_error(err_code, "add_vec") return a end +function sum(v::CVecBasic) + a = Basic() + sum!(a, v) +end +(b1::Basic, b2::Basic, b3::Basic, bs...) = sum(convert(CVecBasic, [b1, b2, b3, bs...])) +(b1::Basic, b2::Basic, b3, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...) @@ -51,12 +60,15 @@ end +(b1, b2, b3::Basic, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...) -function prod(v::CVecBasic) - a = Basic() +function prod!(a::Basic, v::CVecBasic) err_code = ccall((:basic_mul_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr) throw_if_error(err_code, "mul_vec") return a end +function prod(v::CVecBasic) + a = Basic() + prod!(a, v) +end *(b1::Basic, b2::Basic, b3::Basic, bs::Vararg{Number, N}) where {N} = prod(convert(CVecBasic, [b1, b2, b3, bs...])) *(b1::Basic, b2::Basic, b3::Number, bs::Vararg{Number, N}) where {N} = *(Basic(b1), Basic(b2), Basic(b3), bs...) diff --git a/src/numerics.jl b/src/numerics.jl index 52734fa..74ea3a3 100644 --- a/src/numerics.jl +++ b/src/numerics.jl @@ -131,6 +131,7 @@ end ## deprecate N(::BasicType) N(b::BasicType{T}) where {T} = N(convert(Basic, b), T) +## Conversions SymEngine -> Julia ## define convert(T, x) methods leveraging N() when needed function convert(::Type{Float64}, x::Basic) is_a_RealDouble(x) && return _convert(Cdouble, x) @@ -170,11 +171,14 @@ Base.Real(x::Basic) = convert(Real, x) ## Rational -- p/q parts -function as_numer_denom(x::Basic) - a, b = Basic(), Basic() +function as_numer_denom!(a::Basic, b::Basic, x::Basic) ccall((:basic_as_numer_denom, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b, x) return a, b end +function as_numer_denom(x::Basic) + a, b = Basic(), Basic() + as_numer_denom!(a,b,x) +end as_numer_denom(x::BasicType) = as_numer_denom(Basic(x)) denominator(x::SymbolicType) = as_numer_denom(x)[2] diff --git a/src/simplify.jl b/src/simplify.jl index f8a1954..656c829 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -1,4 +1,4 @@ -IMPLEMENT_ONE_ARG_FUNC(:expand, :expand) +IMPLEMENT_ONE_ARG_FUNC(nothing, :expand, :expand) if get_symbol(:basic_cse) != C_NULL function cse(exprs...) diff --git a/src/subs.jl b/src/subs.jl index 9e23f8a..0d4d39e 100644 --- a/src/subs.jl +++ b/src/subs.jl @@ -18,17 +18,26 @@ subs(ex, x=>1, y=>1) # ditto """ function subs(ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType} s = Basic() + subs!(s, ex, var, val) +end + +function subs!(s::Basic, ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType} err_code = ccall((:basic_subs2, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), s, ex, var, val) throw_if_error(err_code, ex) return s end + function subs(ex::T, d::CMapBasicBasic) where T<:SymbolicType s = Basic() + subs!(s, ex, d) +end +function subs!(s::Basic, ex::T, d::CMapBasicBasic) where T<:SymbolicType err_code = ccall((:basic_subs, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ptr{Cvoid}), s, ex, d.ptr) throw_if_error(err_code, ex) return s end + subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d)) subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2]) subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...) diff --git a/src/types.jl b/src/types.jl index 0d34bb3..86c6fb2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -10,20 +10,6 @@ ## ## To control dispatch, one might have `N(b::Basic) = N(BasicType(b))` and then define `N` for types of interest -## Hold a reference to a SymEngine object -mutable struct Basic <: Number - ptr::Ptr{Cvoid} - function Basic() - z = new(C_NULL) - ccall((:basic_new_stack, libsymengine), Nothing, (Ref{Basic}, ), z) - finalizer(basic_free, z) - return z - end - function Basic(v::Ptr{Cvoid}) - z = new(v) - return z - end -end basic_free(b::Basic) = ccall((:basic_free_stack, libsymengine), Nothing, (Ref{Basic}, ), b) @@ -293,20 +279,32 @@ end " Return free symbols in an expression as a `Set`" function free_symbols(ex::Basic) syms = CSetBasic() - ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr) + free_symbols!(syms, ex) convert(Vector, syms) end +function free_symbols!(syms::CSetBasic, ex::Basic) + ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr) + syms +end + free_symbols(ex::BasicType) = free_symbols(Basic(ex)) + _flat(A) = mapreduce(x->isa(x,Array) ? _flat(x) : x, vcat, A, init=Basic[]) # from rosetta code example + free_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([free_symbols(ex) for ex in exs])) free_symbols(exs::Tuple) = unique(_flat([free_symbols(ex) for ex in exs])) "Return function symbols in an expression as a `Set`" function function_symbols(ex::Basic) syms = CSetBasic() - ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex) + function_symbols!(syms, ex) convert(Vector, syms) end +function function_symbols!(syms::CSetBasic, ex::Basic) + ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex) + syms +end + function_symbols(ex::BasicType) = function_symbols(Basic(ex)) function_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([function_symbols(ex) for ex in exs])) function_symbols(exs::Tuple) = unique(_flat([function_symbols(ex) for ex in exs])) @@ -322,23 +320,20 @@ end "Return arguments of a function call as a vector of `Basic` objects" function get_args(ex::Basic) args = CVecBasic() - ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr) + get_args!(args, ex) convert(Vector, args) end +function get_args!(args::CVecBasic, ex::Basic) + ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr) +end + ## so that Dicts will work basic_hash(ex::Basic) = ccall((:basic_hash, libsymengine), UInt, (Ref{Basic}, ), ex) # similar definition as in Base for general objects Base.hash(ex::Basic, h::UInt) = Base.hash_uint(3h - basic_hash(ex)) Base.hash(ex::BasicType, h::UInt) = hash(Basic(ex), h) -function coeff(b::Basic, x::Basic, n::Basic) - c = Basic() - ccall((:basic_coeff, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), c, b, x, n) - return c -end - -coeff(b::Basic, x::Basic) = coeff(b, x, one(Basic)) function Serialization.serialize(s::Serialization.AbstractSerializer, m::Basic) Serialization.serialize_type(s, typeof(m)) diff --git a/test/runtests.jl b/test/runtests.jl index dc74e32..2336e45 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,7 +19,7 @@ let @vars w end @test_throws UndefVarError isdefined(w) -@test_throws Exception show(Basic()) +@test repr(Basic()) == "" # test @vars constructions @vars a, b[0:4], c(), d=>"D" @@ -414,3 +414,5 @@ end close(iobuf) @test deserialized == data end + +VERSION >= v"1.9.0" && include("test-allocations.jl") diff --git a/test/test-allocations.jl b/test/test-allocations.jl new file mode 100644 index 0000000..1f67dce --- /dev/null +++ b/test/test-allocations.jl @@ -0,0 +1,18 @@ +# test for allocations +@vars x y +a = Basic() +@testset "non-allocating(ish) methods" begin + sin(x), cos(x), abs(x) + x^x, x + x, x*x, x-x, x/x # warm up + + @test (@allocations SymEngine.sin!(a,x)) == 0 + @test (@allocations SymEngine.cos!(a,x)) == 0 + @test (@allocations SymEngine.pow!(a,x,x)) == 0 + + # still allocates 1 (or 2) + @test (@allocations SymEngine.add!(a,x,y)) < (@allocations x+y) + @test (@allocations SymEngine.sub!(a,x,y)) < (@allocations x-y) + @test (@allocations SymEngine.mul!(a,x,y)) < (@allocations x*y) + @test (@allocations SymEngine.div!(a,x,y)) < (@allocations x/y) + @test (@allocations SymEngine.abs2!(a,x)) < (@allocations abs2(x)) +end