Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SymEngine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ const have_mpc = have_component("mpc")
const libversion = get_libversion()

include("exceptions.jl")
include("types.jl")
include("ctypes.jl")
include("types.jl")
include("decl.jl")
include("display.jl")
include("mathops.jl")
Expand Down
19 changes: 18 additions & 1 deletion src/ctypes.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
# types from SymEngine to Julia

## Basic
## Hold a reference to a SymEngine object
mutable struct Basic <: Number
ptr::Ptr{Cvoid}
function Basic()
z = new(C_NULL)
ccall((:basic_new_stack, libsymengine), Nothing, (Ref{Basic}, ), z)
finalizer(basic_free, z)
return z
end
function Basic(v::Ptr{Cvoid})
z = new(v)
return z
end
end

## CSetBasic
mutable struct CSetBasic
ptr::Ptr{Cvoid}
Expand Down Expand Up @@ -160,7 +177,7 @@ function CDenseMatrix(x::Array{T, 2}) where T
end


function Base.convert(::Type{Matrix}, x::CDenseMatrix)
function Base.convert(::Type{Matrix}, x::CDenseMatrix)
m,n = Base.size(x)
[x[i,j] for i in 1:m, j in 1:n]
end
Expand Down
2 changes: 1 addition & 1 deletion src/display.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
function toString(b::SymbolicType)
b = Basic(b)
if b.ptr == C_NULL
error("Trying to print an uninitialized SymEngine Basic variable.")
return "<Unitialized Basic value>"
end
a = ccall((:basic_str_julia, libsymengine), Cstring, (Ref{Basic}, ), b)
string = unsafe_string(a)
Expand Down
56 changes: 41 additions & 15 deletions src/mathfuns.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,43 @@
using SpecialFunctions

function IMPLEMENT_ONE_ARG_FUNC(meth, symnm; lib=:basic_)
function IMPLEMENT_ONE_ARG_FUNC(modu, meth, symnm; lib=:basic_)
methbang = Symbol(meth, "!")
if isa(modu, Symbol)
meth = :($modu.$meth)
end
@eval begin
function ($meth)(b::SymbolicType)
a = Basic()
($methbang)(a,b)
a
end
function ($methbang)(a::Basic, b::SymbolicType)
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}), a, b)
throw_if_error(err_code, $(string(meth)))
return a
end
end
end

function IMPLEMENT_TWO_ARG_FUNC(meth, symnm; lib=:basic_)
function IMPLEMENT_TWO_ARG_FUNC(modu, meth, symnm; lib=:basic_)
methbang = Symbol(meth, "!")
if isa(modu, Symbol)
meth = :($modu.$meth)
end

@eval begin
function ($meth)(b1::SymbolicType, b2::Number)
a = Basic()
($methbang)(a,b1,b2)
a
end
function ($methbang)(a::Basic, b1::SymbolicType, b2::Number)
b1, b2 = promote(b1, b2)
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
throw_if_error(err_code, $(string(meth)))
return a
end

end
end

Expand Down Expand Up @@ -59,7 +77,7 @@ for (meth, libnm, modu) in [
(:floor, :floor, :Base)
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
end

for (meth, libnm, modu) in [
Expand All @@ -71,7 +89,7 @@ for (meth, libnm, modu) in [
(:erfc, :erfc, :SpecialFunctions)
]
eval(:(import $modu.$meth))
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
end

for (meth, libnm, modu) in [
Expand All @@ -80,23 +98,31 @@ for (meth, libnm, modu) in [
(:loggamma,:loggamma,:SpecialFunctions),
]
eval(:(import $modu.$meth))
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
IMPLEMENT_TWO_ARG_FUNC(modu, meth, libnm)
end

Base.abs2(x::SymEngine.Basic) = abs(x)^2

function abs2!(a::Basic, x::Basic)
abs!(a, x)
mul!(a, a, a)
a
end
function Base.abs2(x::Basic)
a = Basic()
abs2!(a, x)
a
end


if get_symbol(:basic_atan2) != C_NULL
import Base.atan
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)
IMPLEMENT_TWO_ARG_FUNC(:Base, :atan, :atan2)
end

# export not import
for (meth, libnm) in [
(:lambertw,:lambertw), # in add-on packages, not base
]
IMPLEMENT_ONE_ARG_FUNC(meth, libnm)
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm)
eval(Expr(:export, meth))
end

Expand All @@ -118,7 +144,7 @@ for (meth, libnm) in [(:gcd, :gcd),
(:mod, :mod_f),
]
eval(:(import Base.$meth))
IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_)
IMPLEMENT_TWO_ARG_FUNC(:Base, meth, libnm; lib=:ntheory_)
end

Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong
Expand All @@ -129,18 +155,18 @@ Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k))
## but not (:fibonacci,:fibonacci), (:lucas, :lucas) (Basic type is not the signature)
for (meth, libnm) in [(:nextprime,:nextprime)
]
IMPLEMENT_ONE_ARG_FUNC(meth, libnm, lib=:ntheory_)
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm, lib=:ntheory_)
eval(Expr(:export, meth))
end

"Return coefficient of `x^n` term, `x` a symbol"
function coeff(b::Basic, x, n)
c = Basic()
function coeff!(a::Basic, b::Basic, x, n)
out = ccall((:basic_coeff, libsymengine), Nothing,
(Ref{Basic},Ref{Basic},Ref{Basic},Ref{Basic}),
c,b,Basic(x), Basic(n))
c
a,b,Basic(x), Basic(n))
a
end
coeff(b::Basic, x, n) = coeff!(Basic(), b, x, n)

function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T
vec = CVecBasic()
Expand Down
26 changes: 19 additions & 7 deletions src/mathops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ end


## main ops
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (://, :div), (:^, :pow))
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (:^, :pow))
tup = (Base.Symbol("basic_$libnm"), libsymengine)
opbang = Symbol(libnm,:!)
@eval begin
function ($op)(b1::Basic, b2::Basic)
a = Basic()
function ($opbang)(a::Basic, b1::Basic, b2::Basic)
err_code = ccall($tup, Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
throw_if_error(err_code, $(string(libnm)))
return a
end
function ($op)(b1::Basic, b2::Basic)
a = Basic()
($opbang)(a, b1, b2)
return a
end
($op)(b1::BasicType, b2::BasicType) = ($op)(Basic(b1), Basic(b2))
end
end
Expand All @@ -30,16 +35,20 @@ end
# In contrast to other standard operations such as `+`, `*`, `-`, and `/`,
# Julia doesn't implement a general fallback of `//` for `Number`s promoting
# the input arguments. Thus, we implement this here explicitly.
Base.:(//)(b1::SymbolicType, b2::SymbolicType) = b1 / b2
Base.:(//)(b1::SymbolicType, b2::Number) = //(promote(b1, b2)...)
Base.:(//)(b1::Number, b2::SymbolicType) = //(promote(b1, b2)...)


function sum(v::CVecBasic)
a = Basic()
function sum!(a::Basic, v::CVecBasic)
err_code = ccall((:basic_add_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
throw_if_error(err_code, "add_vec")
return a
end
function sum(v::CVecBasic)
a = Basic()
sum!(a, v)
end

+(b1::Basic, b2::Basic, b3::Basic, bs...) = sum(convert(CVecBasic, [b1, b2, b3, bs...]))
+(b1::Basic, b2::Basic, b3, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)
Expand All @@ -51,12 +60,15 @@ end
+(b1, b2, b3::Basic, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)


function prod(v::CVecBasic)
a = Basic()
function prod!(a::Basic, v::CVecBasic)
err_code = ccall((:basic_mul_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
throw_if_error(err_code, "mul_vec")
return a
end
function prod(v::CVecBasic)
a = Basic()
prod!(a, v)
end

*(b1::Basic, b2::Basic, b3::Basic, bs::Vararg{Number, N}) where {N} = prod(convert(CVecBasic, [b1, b2, b3, bs...]))
*(b1::Basic, b2::Basic, b3::Number, bs::Vararg{Number, N}) where {N} = *(Basic(b1), Basic(b2), Basic(b3), bs...)
Expand Down
8 changes: 6 additions & 2 deletions src/numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ end
## deprecate N(::BasicType)
N(b::BasicType{T}) where {T} = N(convert(Basic, b), T)

## Conversions SymEngine -> Julia
## define convert(T, x) methods leveraging N() when needed
function convert(::Type{Float64}, x::Basic)
is_a_RealDouble(x) && return _convert(Cdouble, x)
Expand Down Expand Up @@ -170,11 +171,14 @@ Base.Real(x::Basic) = convert(Real, x)


## Rational -- p/q parts
function as_numer_denom(x::Basic)
a, b = Basic(), Basic()
function as_numer_denom!(a::Basic, b::Basic, x::Basic)
ccall((:basic_as_numer_denom, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b, x)
return a, b
end
function as_numer_denom(x::Basic)
a, b = Basic(), Basic()
as_numer_denom!(a,b,x)
end

as_numer_denom(x::BasicType) = as_numer_denom(Basic(x))
denominator(x::SymbolicType) = as_numer_denom(x)[2]
Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
IMPLEMENT_ONE_ARG_FUNC(:expand, :expand)
IMPLEMENT_ONE_ARG_FUNC(nothing, :expand, :expand)

if get_symbol(:basic_cse) != C_NULL
function cse(exprs...)
Expand Down
9 changes: 9 additions & 0 deletions src/subs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@ subs(ex, x=>1, y=>1) # ditto
"""
function subs(ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
s = Basic()
subs!(s, ex, var, val)
end

function subs!(s::Basic, ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
err_code = ccall((:basic_subs2, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), s, ex, var, val)
throw_if_error(err_code, ex)
return s
end

function subs(ex::T, d::CMapBasicBasic) where T<:SymbolicType
s = Basic()
subs!(s, ex, d)
end
function subs!(s::Basic, ex::T, d::CMapBasicBasic) where T<:SymbolicType
err_code = ccall((:basic_subs, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ptr{Cvoid}), s, ex, d.ptr)
throw_if_error(err_code, ex)
return s
end


subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d))
subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2])
subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...)
Expand Down
43 changes: 19 additions & 24 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@
##
## To control dispatch, one might have `N(b::Basic) = N(BasicType(b))` and then define `N` for types of interest

## Hold a reference to a SymEngine object
mutable struct Basic <: Number
ptr::Ptr{Cvoid}
function Basic()
z = new(C_NULL)
ccall((:basic_new_stack, libsymengine), Nothing, (Ref{Basic}, ), z)
finalizer(basic_free, z)
return z
end
function Basic(v::Ptr{Cvoid})
z = new(v)
return z
end
end

basic_free(b::Basic) = ccall((:basic_free_stack, libsymengine), Nothing, (Ref{Basic}, ), b)

Expand Down Expand Up @@ -293,20 +279,32 @@ end
" Return free symbols in an expression as a `Set`"
function free_symbols(ex::Basic)
syms = CSetBasic()
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
free_symbols!(syms, ex)
convert(Vector, syms)
end
function free_symbols!(syms::CSetBasic, ex::Basic)
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
syms
end

free_symbols(ex::BasicType) = free_symbols(Basic(ex))

_flat(A) = mapreduce(x->isa(x,Array) ? _flat(x) : x, vcat, A, init=Basic[]) # from rosetta code example

free_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([free_symbols(ex) for ex in exs]))
free_symbols(exs::Tuple) = unique(_flat([free_symbols(ex) for ex in exs]))

"Return function symbols in an expression as a `Set`"
function function_symbols(ex::Basic)
syms = CSetBasic()
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
function_symbols!(syms, ex)
convert(Vector, syms)
end
function function_symbols!(syms::CSetBasic, ex::Basic)
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
syms
end

function_symbols(ex::BasicType) = function_symbols(Basic(ex))
function_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([function_symbols(ex) for ex in exs]))
function_symbols(exs::Tuple) = unique(_flat([function_symbols(ex) for ex in exs]))
Expand All @@ -322,23 +320,20 @@ end
"Return arguments of a function call as a vector of `Basic` objects"
function get_args(ex::Basic)
args = CVecBasic()
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
get_args!(args, ex)
convert(Vector, args)
end

function get_args!(args::CVecBasic, ex::Basic)
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
end

## so that Dicts will work
basic_hash(ex::Basic) = ccall((:basic_hash, libsymengine), UInt, (Ref{Basic}, ), ex)
# similar definition as in Base for general objects
Base.hash(ex::Basic, h::UInt) = Base.hash_uint(3h - basic_hash(ex))
Base.hash(ex::BasicType, h::UInt) = hash(Basic(ex), h)

function coeff(b::Basic, x::Basic, n::Basic)
c = Basic()
ccall((:basic_coeff, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), c, b, x, n)
return c
end

coeff(b::Basic, x::Basic) = coeff(b, x, one(Basic))

function Serialization.serialize(s::Serialization.AbstractSerializer, m::Basic)
Serialization.serialize_type(s, typeof(m))
Expand Down
Loading
Loading