Skip to content

Commit bd8602c

Browse files
authored
Merge pull request #288 from jverzani/mutate
use mutating versions of ops
2 parents 267124c + 67c80d8 commit bd8602c

File tree

11 files changed

+136
-53
lines changed

11 files changed

+136
-53
lines changed

src/SymEngine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ const have_mpc = have_component("mpc")
2121
const libversion = get_libversion()
2222

2323
include("exceptions.jl")
24-
include("types.jl")
2524
include("ctypes.jl")
25+
include("types.jl")
2626
include("decl.jl")
2727
include("display.jl")
2828
include("mathops.jl")

src/ctypes.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,21 @@
11
# types from SymEngine to Julia
2+
3+
## Basic
4+
## Hold a reference to a SymEngine object
5+
mutable struct Basic <: Number
6+
ptr::Ptr{Cvoid}
7+
function Basic()
8+
z = new(C_NULL)
9+
ccall((:basic_new_stack, libsymengine), Nothing, (Ref{Basic}, ), z)
10+
finalizer(basic_free, z)
11+
return z
12+
end
13+
function Basic(v::Ptr{Cvoid})
14+
z = new(v)
15+
return z
16+
end
17+
end
18+
219
## CSetBasic
320
mutable struct CSetBasic
421
ptr::Ptr{Cvoid}
@@ -160,7 +177,7 @@ function CDenseMatrix(x::Array{T, 2}) where T
160177
end
161178

162179

163-
function Base.convert(::Type{Matrix}, x::CDenseMatrix)
180+
function Base.convert(::Type{Matrix}, x::CDenseMatrix)
164181
m,n = Base.size(x)
165182
[x[i,j] for i in 1:m, j in 1:n]
166183
end

src/display.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
function toString(b::SymbolicType)
33
b = Basic(b)
44
if b.ptr == C_NULL
5-
error("Trying to print an uninitialized SymEngine Basic variable.")
5+
return "<Unitialized Basic value>"
66
end
77
a = ccall((:basic_str_julia, libsymengine), Cstring, (Ref{Basic}, ), b)
88
string = unsafe_string(a)

src/mathfuns.jl

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,43 @@
11
using SpecialFunctions
22

3-
function IMPLEMENT_ONE_ARG_FUNC(meth, symnm; lib=:basic_)
3+
function IMPLEMENT_ONE_ARG_FUNC(modu, meth, symnm; lib=:basic_)
4+
methbang = Symbol(meth, "!")
5+
if isa(modu, Symbol)
6+
meth = :($modu.$meth)
7+
end
48
@eval begin
59
function ($meth)(b::SymbolicType)
610
a = Basic()
11+
($methbang)(a,b)
12+
a
13+
end
14+
function ($methbang)(a::Basic, b::SymbolicType)
715
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}), a, b)
816
throw_if_error(err_code, $(string(meth)))
917
return a
1018
end
1119
end
1220
end
1321

14-
function IMPLEMENT_TWO_ARG_FUNC(meth, symnm; lib=:basic_)
22+
function IMPLEMENT_TWO_ARG_FUNC(modu, meth, symnm; lib=:basic_)
23+
methbang = Symbol(meth, "!")
24+
if isa(modu, Symbol)
25+
meth = :($modu.$meth)
26+
end
27+
1528
@eval begin
1629
function ($meth)(b1::SymbolicType, b2::Number)
1730
a = Basic()
31+
($methbang)(a,b1,b2)
32+
a
33+
end
34+
function ($methbang)(a::Basic, b1::SymbolicType, b2::Number)
1835
b1, b2 = promote(b1, b2)
1936
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
2037
throw_if_error(err_code, $(string(meth)))
2138
return a
2239
end
40+
2341
end
2442
end
2543

@@ -59,7 +77,7 @@ for (meth, libnm, modu) in [
5977
(:floor, :floor, :Base)
6078
]
6179
eval(:(import $modu.$meth))
62-
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
80+
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
6381
end
6482

6583
for (meth, libnm, modu) in [
@@ -71,7 +89,7 @@ for (meth, libnm, modu) in [
7189
(:erfc, :erfc, :SpecialFunctions)
7290
]
7391
eval(:(import $modu.$meth))
74-
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
92+
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
7593
end
7694

7795
for (meth, libnm, modu) in [
@@ -80,23 +98,31 @@ for (meth, libnm, modu) in [
8098
(:loggamma,:loggamma,:SpecialFunctions),
8199
]
82100
eval(:(import $modu.$meth))
83-
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
101+
IMPLEMENT_TWO_ARG_FUNC(modu, meth, libnm)
84102
end
85103

86-
Base.abs2(x::SymEngine.Basic) = abs(x)^2
87-
104+
function abs2!(a::Basic, x::Basic)
105+
abs!(a, x)
106+
mul!(a, a, a)
107+
a
108+
end
109+
function Base.abs2(x::Basic)
110+
a = Basic()
111+
abs2!(a, x)
112+
a
113+
end
88114

89115

90116
if get_symbol(:basic_atan2) != C_NULL
91117
import Base.atan
92-
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)
118+
IMPLEMENT_TWO_ARG_FUNC(:Base, :atan, :atan2)
93119
end
94120

95121
# export not import
96122
for (meth, libnm) in [
97123
(:lambertw,:lambertw), # in add-on packages, not base
98124
]
99-
IMPLEMENT_ONE_ARG_FUNC(meth, libnm)
125+
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm)
100126
eval(Expr(:export, meth))
101127
end
102128

@@ -118,7 +144,7 @@ for (meth, libnm) in [(:gcd, :gcd),
118144
(:mod, :mod_f),
119145
]
120146
eval(:(import Base.$meth))
121-
IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_)
147+
IMPLEMENT_TWO_ARG_FUNC(:Base, meth, libnm; lib=:ntheory_)
122148
end
123149

124150
Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong
@@ -129,18 +155,18 @@ Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k))
129155
## but not (:fibonacci,:fibonacci), (:lucas, :lucas) (Basic type is not the signature)
130156
for (meth, libnm) in [(:nextprime,:nextprime)
131157
]
132-
IMPLEMENT_ONE_ARG_FUNC(meth, libnm, lib=:ntheory_)
158+
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm, lib=:ntheory_)
133159
eval(Expr(:export, meth))
134160
end
135161

136162
"Return coefficient of `x^n` term, `x` a symbol"
137-
function coeff(b::Basic, x, n)
138-
c = Basic()
163+
function coeff!(a::Basic, b::Basic, x, n)
139164
out = ccall((:basic_coeff, libsymengine), Nothing,
140165
(Ref{Basic},Ref{Basic},Ref{Basic},Ref{Basic}),
141-
c,b,Basic(x), Basic(n))
142-
c
166+
a,b,Basic(x), Basic(n))
167+
a
143168
end
169+
coeff(b::Basic, x, n) = coeff!(Basic(), b, x, n)
144170

145171
function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T
146172
vec = CVecBasic()

src/mathops.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@ end
88

99

1010
## main ops
11-
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (://, :div), (:^, :pow))
11+
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (:^, :pow))
1212
tup = (Base.Symbol("basic_$libnm"), libsymengine)
13+
opbang = Symbol(libnm,:!)
1314
@eval begin
14-
function ($op)(b1::Basic, b2::Basic)
15-
a = Basic()
15+
function ($opbang)(a::Basic, b1::Basic, b2::Basic)
1616
err_code = ccall($tup, Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
1717
throw_if_error(err_code, $(string(libnm)))
1818
return a
1919
end
20+
function ($op)(b1::Basic, b2::Basic)
21+
a = Basic()
22+
($opbang)(a, b1, b2)
23+
return a
24+
end
2025
($op)(b1::BasicType, b2::BasicType) = ($op)(Basic(b1), Basic(b2))
2126
end
2227
end
@@ -30,16 +35,20 @@ end
3035
# In contrast to other standard operations such as `+`, `*`, `-`, and `/`,
3136
# Julia doesn't implement a general fallback of `//` for `Number`s promoting
3237
# the input arguments. Thus, we implement this here explicitly.
38+
Base.:(//)(b1::SymbolicType, b2::SymbolicType) = b1 / b2
3339
Base.:(//)(b1::SymbolicType, b2::Number) = //(promote(b1, b2)...)
3440
Base.:(//)(b1::Number, b2::SymbolicType) = //(promote(b1, b2)...)
3541

3642

37-
function sum(v::CVecBasic)
38-
a = Basic()
43+
function sum!(a::Basic, v::CVecBasic)
3944
err_code = ccall((:basic_add_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
4045
throw_if_error(err_code, "add_vec")
4146
return a
4247
end
48+
function sum(v::CVecBasic)
49+
a = Basic()
50+
sum!(a, v)
51+
end
4352

4453
+(b1::Basic, b2::Basic, b3::Basic, bs...) = sum(convert(CVecBasic, [b1, b2, b3, bs...]))
4554
+(b1::Basic, b2::Basic, b3, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)
@@ -51,12 +60,15 @@ end
5160
+(b1, b2, b3::Basic, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)
5261

5362

54-
function prod(v::CVecBasic)
55-
a = Basic()
63+
function prod!(a::Basic, v::CVecBasic)
5664
err_code = ccall((:basic_mul_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
5765
throw_if_error(err_code, "mul_vec")
5866
return a
5967
end
68+
function prod(v::CVecBasic)
69+
a = Basic()
70+
prod!(a, v)
71+
end
6072

6173
*(b1::Basic, b2::Basic, b3::Basic, bs::Vararg{Number, N}) where {N} = prod(convert(CVecBasic, [b1, b2, b3, bs...]))
6274
*(b1::Basic, b2::Basic, b3::Number, bs::Vararg{Number, N}) where {N} = *(Basic(b1), Basic(b2), Basic(b3), bs...)

src/numerics.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ end
131131
## deprecate N(::BasicType)
132132
N(b::BasicType{T}) where {T} = N(convert(Basic, b), T)
133133

134+
## Conversions SymEngine -> Julia
134135
## define convert(T, x) methods leveraging N() when needed
135136
function convert(::Type{Float64}, x::Basic)
136137
is_a_RealDouble(x) && return _convert(Cdouble, x)
@@ -170,11 +171,14 @@ Base.Real(x::Basic) = convert(Real, x)
170171

171172

172173
## Rational -- p/q parts
173-
function as_numer_denom(x::Basic)
174-
a, b = Basic(), Basic()
174+
function as_numer_denom!(a::Basic, b::Basic, x::Basic)
175175
ccall((:basic_as_numer_denom, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b, x)
176176
return a, b
177177
end
178+
function as_numer_denom(x::Basic)
179+
a, b = Basic(), Basic()
180+
as_numer_denom!(a,b,x)
181+
end
178182

179183
as_numer_denom(x::BasicType) = as_numer_denom(Basic(x))
180184
denominator(x::SymbolicType) = as_numer_denom(x)[2]

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
IMPLEMENT_ONE_ARG_FUNC(:expand, :expand)
1+
IMPLEMENT_ONE_ARG_FUNC(nothing, :expand, :expand)
22

33
if get_symbol(:basic_cse) != C_NULL
44
function cse(exprs...)

src/subs.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@ subs(ex, x=>1, y=>1) # ditto
1818
"""
1919
function subs(ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
2020
s = Basic()
21+
subs!(s, ex, var, val)
22+
end
23+
24+
function subs!(s::Basic, ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
2125
err_code = ccall((:basic_subs2, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), s, ex, var, val)
2226
throw_if_error(err_code, ex)
2327
return s
2428
end
29+
2530
function subs(ex::T, d::CMapBasicBasic) where T<:SymbolicType
2631
s = Basic()
32+
subs!(s, ex, d)
33+
end
34+
function subs!(s::Basic, ex::T, d::CMapBasicBasic) where T<:SymbolicType
2735
err_code = ccall((:basic_subs, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ptr{Cvoid}), s, ex, d.ptr)
2836
throw_if_error(err_code, ex)
2937
return s
3038
end
3139

40+
3241
subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d))
3342
subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2])
3443
subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...)

src/types.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,6 @@
1010
##
1111
## To control dispatch, one might have `N(b::Basic) = N(BasicType(b))` and then define `N` for types of interest
1212

13-
## Hold a reference to a SymEngine object
14-
mutable struct Basic <: Number
15-
ptr::Ptr{Cvoid}
16-
function Basic()
17-
z = new(C_NULL)
18-
ccall((:basic_new_stack, libsymengine), Nothing, (Ref{Basic}, ), z)
19-
finalizer(basic_free, z)
20-
return z
21-
end
22-
function Basic(v::Ptr{Cvoid})
23-
z = new(v)
24-
return z
25-
end
26-
end
2713

2814
basic_free(b::Basic) = ccall((:basic_free_stack, libsymengine), Nothing, (Ref{Basic}, ), b)
2915

@@ -293,20 +279,32 @@ end
293279
" Return free symbols in an expression as a `Set`"
294280
function free_symbols(ex::Basic)
295281
syms = CSetBasic()
296-
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
282+
free_symbols!(syms, ex)
297283
convert(Vector, syms)
298284
end
285+
function free_symbols!(syms::CSetBasic, ex::Basic)
286+
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
287+
syms
288+
end
289+
299290
free_symbols(ex::BasicType) = free_symbols(Basic(ex))
291+
300292
_flat(A) = mapreduce(x->isa(x,Array) ? _flat(x) : x, vcat, A, init=Basic[]) # from rosetta code example
293+
301294
free_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([free_symbols(ex) for ex in exs]))
302295
free_symbols(exs::Tuple) = unique(_flat([free_symbols(ex) for ex in exs]))
303296

304297
"Return function symbols in an expression as a `Set`"
305298
function function_symbols(ex::Basic)
306299
syms = CSetBasic()
307-
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
300+
function_symbols!(syms, ex)
308301
convert(Vector, syms)
309302
end
303+
function function_symbols!(syms::CSetBasic, ex::Basic)
304+
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
305+
syms
306+
end
307+
310308
function_symbols(ex::BasicType) = function_symbols(Basic(ex))
311309
function_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([function_symbols(ex) for ex in exs]))
312310
function_symbols(exs::Tuple) = unique(_flat([function_symbols(ex) for ex in exs]))
@@ -322,23 +320,20 @@ end
322320
"Return arguments of a function call as a vector of `Basic` objects"
323321
function get_args(ex::Basic)
324322
args = CVecBasic()
325-
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
323+
get_args!(args, ex)
326324
convert(Vector, args)
327325
end
328326

327+
function get_args!(args::CVecBasic, ex::Basic)
328+
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
329+
end
330+
329331
## so that Dicts will work
330332
basic_hash(ex::Basic) = ccall((:basic_hash, libsymengine), UInt, (Ref{Basic}, ), ex)
331333
# similar definition as in Base for general objects
332334
Base.hash(ex::Basic, h::UInt) = Base.hash_uint(3h - basic_hash(ex))
333335
Base.hash(ex::BasicType, h::UInt) = hash(Basic(ex), h)
334336

335-
function coeff(b::Basic, x::Basic, n::Basic)
336-
c = Basic()
337-
ccall((:basic_coeff, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), c, b, x, n)
338-
return c
339-
end
340-
341-
coeff(b::Basic, x::Basic) = coeff(b, x, one(Basic))
342337

343338
function Serialization.serialize(s::Serialization.AbstractSerializer, m::Basic)
344339
Serialization.serialize_type(s, typeof(m))

0 commit comments

Comments
 (0)