Skip to content

Commit 0595115

Browse files
committed
use mutating versions of ops
1 parent d4f4c53 commit 0595115

File tree

8 files changed

+116
-37
lines changed

8 files changed

+116
-37
lines changed

src/display.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
function toString(b::SymbolicType)
33
b = Basic(b)
44
if b.ptr == C_NULL
5-
error("Trying to print an uninitialized SymEngine Basic variable.")
5+
return "<Unitialized Basic value>"
66
end
77
a = ccall((:basic_str_julia, libsymengine), Cstring, (Ref{Basic}, ), b)
88
string = unsafe_string(a)

src/mathfuns.jl

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,43 @@
11
using SpecialFunctions
22

3-
function IMPLEMENT_ONE_ARG_FUNC(meth, symnm; lib=:basic_)
3+
function IMPLEMENT_ONE_ARG_FUNC(modu, meth, symnm; lib=:basic_)
4+
methbang = Symbol(meth, "!")
5+
if isa(modu, Symbol)
6+
meth = :($modu.$meth)
7+
end
48
@eval begin
59
function ($meth)(b::SymbolicType)
610
a = Basic()
11+
($methbang)(a,b)
12+
a
13+
end
14+
function ($methbang)(a::Basic, b::SymbolicType)
715
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}), a, b)
816
throw_if_error(err_code, $(string(meth)))
917
return a
1018
end
1119
end
1220
end
1321

14-
function IMPLEMENT_TWO_ARG_FUNC(meth, symnm; lib=:basic_)
22+
function IMPLEMENT_TWO_ARG_FUNC(modu, meth, symnm; lib=:basic_)
23+
methbang = Symbol(meth, "!")
24+
if isa(modu, Symbol)
25+
meth = :($modu.$meth)
26+
end
27+
1528
@eval begin
1629
function ($meth)(b1::SymbolicType, b2::Number)
1730
a = Basic()
31+
($methbang)(a,b1,b2)
32+
a
33+
end
34+
function ($methbang)(a::Basic, b1::SymbolicType, b2::Number)
1835
b1, b2 = promote(b1, b2)
1936
err_code = ccall(($(string(lib,symnm)), libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
2037
throw_if_error(err_code, $(string(meth)))
2138
return a
2239
end
40+
2341
end
2442
end
2543

@@ -59,7 +77,7 @@ for (meth, libnm, modu) in [
5977
(:floor, :floor, :Base)
6078
]
6179
eval(:(import $modu.$meth))
62-
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
80+
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
6381
end
6482

6583
for (meth, libnm, modu) in [
@@ -71,7 +89,7 @@ for (meth, libnm, modu) in [
7189
(:erfc, :erfc, :SpecialFunctions)
7290
]
7391
eval(:(import $modu.$meth))
74-
IMPLEMENT_ONE_ARG_FUNC(:($modu.$meth), libnm)
92+
IMPLEMENT_ONE_ARG_FUNC(modu, meth, libnm)
7593
end
7694

7795
for (meth, libnm, modu) in [
@@ -80,23 +98,32 @@ for (meth, libnm, modu) in [
8098
(:loggamma,:loggamma,:SpecialFunctions),
8199
]
82100
eval(:(import $modu.$meth))
83-
IMPLEMENT_TWO_ARG_FUNC(:($modu.$meth), libnm)
101+
IMPLEMENT_TWO_ARG_FUNC(modu, meth, libnm)
84102
end
85103

86-
Base.abs2(x::SymEngine.Basic) = abs(x)^2
87-
104+
const TWO = Basic(2)
105+
function abs2!(a::Basic, x::Basic)
106+
a = abs!(a, x)
107+
a = pow!(a, x, TWO)
108+
a
109+
end
110+
function Base.abs2(x::Basic)
111+
a = Basic()
112+
abs2!(a, x)
113+
a
114+
end
88115

89116

90117
if get_symbol(:basic_atan2) != C_NULL
91118
import Base.atan
92-
IMPLEMENT_TWO_ARG_FUNC(:(Base.atan), :atan2)
119+
IMPLEMENT_TWO_ARG_FUNC(:Base, :atan, :atan2)
93120
end
94121

95122
# export not import
96123
for (meth, libnm) in [
97124
(:lambertw,:lambertw), # in add-on packages, not base
98125
]
99-
IMPLEMENT_ONE_ARG_FUNC(meth, libnm)
126+
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm)
100127
eval(Expr(:export, meth))
101128
end
102129

@@ -118,7 +145,7 @@ for (meth, libnm) in [(:gcd, :gcd),
118145
(:mod, :mod_f),
119146
]
120147
eval(:(import Base.$meth))
121-
IMPLEMENT_TWO_ARG_FUNC(:(Base.$meth), libnm, lib=:ntheory_)
148+
IMPLEMENT_TWO_ARG_FUNC(:Base, meth, libnm; lib=:ntheory_)
122149
end
123150

124151
Base.binomial(n::Basic, k::Number) = binomial(N(n), N(k)) #ntheory_binomial seems wrong
@@ -129,18 +156,18 @@ Base.factorial(n::SymbolicType, k) = factorial(N(n), N(k))
129156
## but not (:fibonacci,:fibonacci), (:lucas, :lucas) (Basic type is not the signature)
130157
for (meth, libnm) in [(:nextprime,:nextprime)
131158
]
132-
IMPLEMENT_ONE_ARG_FUNC(meth, libnm, lib=:ntheory_)
159+
IMPLEMENT_ONE_ARG_FUNC(nothing, meth, libnm, lib=:ntheory_)
133160
eval(Expr(:export, meth))
134161
end
135162

136163
"Return coefficient of `x^n` term, `x` a symbol"
137-
function coeff(b::Basic, x, n)
138-
c = Basic()
164+
function coeff!(a::Basic, b::Basic, x, n)
139165
out = ccall((:basic_coeff, libsymengine), Nothing,
140166
(Ref{Basic},Ref{Basic},Ref{Basic},Ref{Basic}),
141-
c,b,Basic(x), Basic(n))
142-
c
167+
a,b,Basic(x), Basic(n))
168+
a
143169
end
170+
coeff(b::Basic, x, n) = coeff!(Basic(), b, x, n)
144171

145172
function Base.convert(::Type{CVecBasic}, x::Vector{T}) where T
146173
vec = CVecBasic()

src/mathops.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@ end
88

99

1010
## main ops
11-
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (://, :div), (:^, :pow))
11+
for (op, libnm) in ((:+, :add), (:-, :sub), (:*, :mul), (:/, :div), (:^, :pow))
1212
tup = (Base.Symbol("basic_$libnm"), libsymengine)
13+
opbang = Symbol(libnm,:!)
1314
@eval begin
14-
function ($op)(b1::Basic, b2::Basic)
15-
a = Basic()
15+
function ($opbang)(a::Basic, b1::Basic, b2::Basic)
1616
err_code = ccall($tup, Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
1717
throw_if_error(err_code, $(string(libnm)))
1818
return a
1919
end
20+
function ($op)(b1::Basic, b2::Basic)
21+
a = Basic()
22+
($opbang)(a, b1, b2)
23+
return a
24+
end
2025
($op)(b1::BasicType, b2::BasicType) = ($op)(Basic(b1), Basic(b2))
2126
end
2227
end
@@ -30,16 +35,20 @@ end
3035
# In contrast to other standard operations such as `+`, `*`, `-`, and `/`,
3136
# Julia doesn't implement a general fallback of `//` for `Number`s promoting
3237
# the input arguments. Thus, we implement this here explicitly.
38+
Base.:(//)(b1::SymbolicType, b2::SymbolicType) = b1 / b2
3339
Base.:(//)(b1::SymbolicType, b2::Number) = //(promote(b1, b2)...)
3440
Base.:(//)(b1::Number, b2::SymbolicType) = //(promote(b1, b2)...)
3541

3642

37-
function sum(v::CVecBasic)
38-
a = Basic()
43+
function sum!(a::Basic, v::CVecBasic)
3944
err_code = ccall((:basic_add_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
4045
throw_if_error(err_code, "add_vec")
4146
return a
4247
end
48+
function sum(v::CVecBasic)
49+
a = Basic()
50+
sum!(a, b)
51+
end
4352

4453
+(b1::Basic, b2::Basic, b3::Basic, bs...) = sum(convert(CVecBasic, [b1, b2, b3, bs...]))
4554
+(b1::Basic, b2::Basic, b3, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)
@@ -51,12 +60,15 @@ end
5160
+(b1, b2, b3::Basic, bs...) = +(Basic(b1), Basic(b2), Basic(b3), bs...)
5261

5362

54-
function prod(v::CVecBasic)
55-
a = Basic()
63+
function prod!(a::Basic, v::CVecBasic)
5664
err_code = ccall((:basic_mul_vec, libsymengine), Cuint, (Ref{Basic}, Ptr{Cvoid}), a, v.ptr)
5765
throw_if_error(err_code, "mul_vec")
5866
return a
5967
end
68+
function prod(v::CVecBasic)
69+
a = Basic()
70+
prod!(a, b)
71+
end
6072

6173
*(b1::Basic, b2::Basic, b3::Basic, bs::Vararg{Number, N}) where {N} = prod(convert(CVecBasic, [b1, b2, b3, bs...]))
6274
*(b1::Basic, b2::Basic, b3::Number, bs::Vararg{Number, N}) where {N} = *(Basic(b1), Basic(b2), Basic(b3), bs...)

src/numerics.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,14 @@ end
159159

160160

161161
## Conversions SymEngine -> Julia
162-
function as_numer_denom(x::Basic)
163-
a, b = Basic(), Basic()
162+
function as_numer_denom!(a::Basic, b::Basic, x::Basic)
164163
ccall((:basic_as_numer_denom, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b, x)
165164
return a, b
166165
end
166+
function as_numer_denom(x::Basic)
167+
a, b = Basic(), Basic()
168+
as_numer_denom!(a,b,x)
169+
end
167170

168171
as_numer_denom(x::BasicType) = as_numer_denom(Basic(x))
169172
denominator(x::SymbolicType) = as_numer_denom(x)[2]

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
IMPLEMENT_ONE_ARG_FUNC(:expand, :expand)
1+
IMPLEMENT_ONE_ARG_FUNC(nothing, :expand, :expand)
22

33
if get_symbol(:basic_cse) != C_NULL
44
function cse(exprs...)

src/subs.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@ subs(ex, x=>1, y=>1) # ditto
1818
"""
1919
function subs(ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
2020
s = Basic()
21+
subs!(s, ex, var, val)
22+
end
23+
24+
function subs!(s::Basic, ex::T, var::S, val) where {T<:SymbolicType, S<:SymbolicType}
2125
err_code = ccall((:basic_subs2, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), s, ex, var, val)
2226
throw_if_error(err_code, ex)
2327
return s
2428
end
29+
2530
function subs(ex::T, d::CMapBasicBasic) where T<:SymbolicType
2631
s = Basic()
32+
subs!(s, ex, d)
33+
end
34+
function subs!(s::Basic, ex::T, d::CMapBasicBasic) where T<:SymbolicType
2735
err_code = ccall((:basic_subs, libsymengine), Cuint, (Ref{Basic}, Ref{Basic}, Ptr{Cvoid}), s, ex, d.ptr)
2836
throw_if_error(err_code, ex)
2937
return s
3038
end
3139

40+
3241
subs(ex::T, d::AbstractDict) where {T<:SymbolicType} = subs(ex, CMapBasicBasic(d))
3342
subs(ex::T, y::Tuple{S, Any}) where {T <: SymbolicType, S<:SymbolicType} = subs(ex, y[1], y[2])
3443
subs(ex::T, y::Tuple{S, Any}, args...) where {T <: SymbolicType, S<:SymbolicType} = subs(subs(ex, y), args...)

src/types.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,20 +280,32 @@ end
280280
" Return free symbols in an expression as a `Set`"
281281
function free_symbols(ex::Basic)
282282
syms = CSetBasic()
283-
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
283+
free_symbols!(syms, ex)
284284
convert(Vector, syms)
285285
end
286+
function free_symbols!(syms, ex::Basic)
287+
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
288+
syms
289+
end
290+
286291
free_symbols(ex::BasicType) = free_symbols(Basic(ex))
292+
287293
_flat(A) = mapreduce(x->isa(x,Array) ? _flat(x) : x, vcat, A, init=Basic[]) # from rosetta code example
294+
288295
free_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([free_symbols(ex) for ex in exs]))
289296
free_symbols(exs::Tuple) = unique(_flat([free_symbols(ex) for ex in exs]))
290297

291298
"Return function symbols in an expression as a `Set`"
292299
function function_symbols(ex::Basic)
293300
syms = CSetBasic()
294-
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
301+
function_symbols!(syms, ex)
295302
convert(Vector, syms)
296303
end
304+
function function_symbols!(syms, ex::Basic)
305+
ccall((:basic_function_symbols, libsymengine), Nothing, (Ptr{Cvoid}, Ref{Basic}), syms.ptr, ex)
306+
syms
307+
end
308+
297309
function_symbols(ex::BasicType) = function_symbols(Basic(ex))
298310
function_symbols(exs::Array{T}) where {T<:SymbolicType} = unique(_flat([function_symbols(ex) for ex in exs]))
299311
function_symbols(exs::Tuple) = unique(_flat([function_symbols(ex) for ex in exs]))
@@ -309,23 +321,20 @@ end
309321
"Return arguments of a function call as a vector of `Basic` objects"
310322
function get_args(ex::Basic)
311323
args = CVecBasic()
312-
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
324+
get_args!(args, ex)
313325
convert(Vector, args)
314326
end
315327

328+
function get_args!(args, ex::Basic)
329+
ccall((:basic_get_args, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, args.ptr)
330+
end
331+
316332
## so that Dicts will work
317333
basic_hash(ex::Basic) = ccall((:basic_hash, libsymengine), UInt, (Ref{Basic}, ), ex)
318334
# similar definition as in Base for general objects
319335
Base.hash(ex::Basic, h::UInt) = Base.hash_uint(3h - basic_hash(ex))
320336
Base.hash(ex::BasicType, h::UInt) = hash(Basic(ex), h)
321337

322-
function coeff(b::Basic, x::Basic, n::Basic)
323-
c = Basic()
324-
ccall((:basic_coeff, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}, Ref{Basic}), c, b, x, n)
325-
return c
326-
end
327-
328-
coeff(b::Basic, x::Basic) = coeff(b, x, one(Basic))
329338

330339
function Serialization.serialize(s::Serialization.AbstractSerializer, m::Basic)
331340
Serialization.serialize_type(s, typeof(m))

test/runtests.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ let
1919
@vars w
2020
end
2121
@test_throws UndefVarError isdefined(w)
22-
@test_throws Exception show(Basic())
22+
@test repr(Basic()) == "<Unitialized Basic value>"
2323

2424
# test @vars constructions
2525
@vars a, b[0:4], c(), d=>"D"
@@ -354,3 +354,22 @@ end
354354
close(iobuf)
355355
@test deserialized == data
356356
end
357+
358+
@vars a x y
359+
@testset "non-allocating methods" begin
360+
SymEngine.sin!(a,x); SymEngine.cos!(a,x); SymEngine.abs!(a,x)
361+
SymEngine.pow!(a,x,x);
362+
SymEngine.add!(a,x,x);SymEngine.mul!(a,x,x)
363+
SymEngine.sub!(a,x,x);SymEngine.div!(a,x,x)
364+
@test (@allocations SymEngine.sin!(a,x)) == 0
365+
@test (@allocations SymEngine.cos!(a,x)) == 0
366+
@test (@allocations SymEngine.abs2!(a,x)) == 0
367+
@test (@allocations SymEngine.pow!(a,x,x)) == 0
368+
369+
# still allocates 1 (or 2)
370+
@test (@allocations SymEngine.add!(a,x,y)) < (@allocations x+y)
371+
@test (@allocations SymEngine.sub!(a,x,y)) < (@allocations x-y)
372+
@test (@allocations SymEngine.mul!(a,x,y)) < (@allocations x*y)
373+
@test (@allocations SymEngine.div!(a,x,y)) < (@allocations x/y)
374+
375+
end

0 commit comments

Comments
 (0)