Skip to content

Commit 6519f4c

Browse files
authored
Merge pull request #285 from jverzani/N_speedup
2 parents d4f4c53 + 50a93d2 commit 6519f4c

File tree

4 files changed

+223
-115
lines changed

4 files changed

+223
-115
lines changed

src/mathops.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Base.one(::Type{T}) where {T<:BasicType} = BasicType(Basic(1))
8383
## Math constants
8484
## no oo!
8585

86-
for op in [:IM, :PI, :E, :EulerGamma, :Catalan, :oo, :zoo, :NAN]
86+
for op in [:IM, :PI, :E, :EulerGamma, :Catalan, :GoldenRatio, :oo, :zoo, :NAN]
8787
@eval begin
8888
const $op = Basic(C_NULL)
8989
end
@@ -108,6 +108,7 @@ function init_constants()
108108
@init_constant E E
109109
@init_constant EulerGamma EulerGamma
110110
@init_constant Catalan Catalan
111+
@init_constant GoldenRatio GoldenRatio
111112
@init_constant oo infinity
112113
@init_constant zoo complex_infinity
113114
@init_constant NAN nan

src/numerics.jl

Lines changed: 153 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Base: trunc, ceil, floor, round
44

55

66
function evalf(b::Basic, bits::Integer=53, real::Bool=false)
7-
!isfinite(b) && return b
7+
(b == oo || b == zoo || b == NAN) && return b
88
c = Basic()
99
bits > 53 && real && (have_mpfr || throw(ArgumentError("libsymengine has to be compiled with MPFR for this feature")))
1010
bits > 53 && !real && (have_mpc || throw(ArgumentError("libsymengine has to be compiled with MPC for this feature")))
@@ -16,89 +16,42 @@ function evalf(b::Basic, bits::Integer=53, real::Bool=false)
1616
end
1717
end
1818

19-
## Conversions from SymEngine -> Julia at the ccall level
20-
function convert(::Type{BigInt}, b::BasicType{Val{:Integer}})
19+
## Conversions from SymEngine.Basic -> Julia at the ccall level
20+
function _convert(::Type{BigInt}, b::Basic)
2121
a = BigInt()
22-
c = Basic(b)
23-
ccall((:integer_get_mpz, libsymengine), Nothing, (Ref{BigInt}, Ref{Basic}), a, c)
22+
_convert_bigint!(a, b)
2423
return a
2524
end
2625

26+
function _convert_bigint!(a::BigInt, b::Basic) # non-allocating (sometimes)
27+
is_a_Integer(b) || throw(ArgumentError("Not an integer"))
28+
ccall((:integer_get_mpz, libsymengine), Nothing, (Ref{BigInt}, Ref{Basic}), a, b)
29+
a
30+
end
2731

28-
function convert(::Type{BigFloat}, b::BasicType{Val{:RealMPFR}})
29-
c = Basic(b)
32+
function _convert(::Type{Int}, b::Basic)
33+
is_a_Integer(b) || throw(ArgumentError("Not an integer"))
34+
ccall((:integer_get_si, libsymengine), Int, (Ref{Basic},), b)
35+
end
36+
37+
function _convert(::Type{BigFloat}, b::Basic)
3038
a = BigFloat()
31-
ccall((:real_mpfr_get, libsymengine), Nothing, (Ref{BigFloat}, Ref{Basic}), a, c)
39+
_convert_bigfloat!(a, b)
3240
return a
3341
end
3442

35-
function convert(::Type{Cdouble}, b::BasicType{Val{:RealDouble}})
36-
c = Basic(b)
37-
return ccall((:real_double_get_d, libsymengine), Cdouble, (Ref{Basic},), c)
43+
function _convert_bigfloat!(a::BigFloat, b::Basic) # non-allocating
44+
is_a_RealMPFR(b) || throw("Not a big value")
45+
ccall((:real_mpfr_get, libsymengine), Nothing, (Ref{BigFloat}, Ref{Basic}), a, b)
46+
a
3847
end
3948

40-
if SymEngine.libversion >= VersionNumber("0.4.0")
41-
42-
function real(b::BasicComplexNumber)
43-
c = Basic(b)
44-
a = Basic()
45-
ccall((:complex_base_real_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
46-
return a
47-
end
48-
49-
function imag(b::BasicComplexNumber)
50-
c = Basic(b)
51-
a = Basic()
52-
ccall((:complex_base_imaginary_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
53-
return a
54-
end
55-
56-
else
57-
58-
function real(b::BasicType{Val{:ComplexDouble}})
59-
c = Basic(b)
60-
a = Basic()
61-
ccall((:complex_double_real_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
62-
return a
63-
end
64-
65-
function imag(b::BasicType{Val{:ComplexDouble}})
66-
c = Basic(b)
67-
a = Basic()
68-
ccall((:complex_double_imaginary_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
69-
return a
70-
end
71-
72-
function real(b::BasicType{Val{:Complex}})
73-
c = Basic(b)
74-
a = Basic()
75-
ccall((:complex_real_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
76-
return a
77-
end
78-
79-
function imag(b::BasicType{Val{:Complex}})
80-
c = Basic(b)
81-
a = Basic()
82-
ccall((:complex_imaginary_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
83-
return a
84-
end
85-
86-
function real(b::BasicType{Val{:ComplexMPC}})
87-
c = Basic(b)
88-
a = Basic()
89-
ccall((:complex_mpc_real_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
90-
return a
91-
end
92-
93-
function imag(b::BasicType{Val{:ComplexMPC}})
94-
c = Basic(b)
95-
a = Basic()
96-
ccall((:complex_mpc_imaginary_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, c)
97-
return a
98-
end
99-
49+
function _convert(::Type{Cdouble}, b::Basic)
50+
is_a_RealDouble(b) || throw(ArgumentError("Not a real double"))
51+
return ccall((:real_double_get_d, libsymengine), Cdouble, (Ref{Basic},), b)
10052
end
10153

54+
10255
##################################################
10356
# N
10457
"""
@@ -109,10 +62,11 @@ Convert a SymEngine numeric value into a Julian number
10962
N(a::Integer) = a
11063
N(a::Rational) = a
11164
N(a::Complex) = a
112-
N(b::Basic) = N(BasicType(b))
11365

114-
function N(b::BasicType{Val{:Integer}})
115-
a = convert(BigInt, b)
66+
N(b::Basic) = N(get_symengine_class_val(b), b)
67+
68+
function N(::Val{:Integer}, b::Basic)
69+
a = _convert(BigInt, b)
11670
if (a.size > 1 || a.size < -1)
11771
return a
11872
elseif (a.size == 0)
@@ -130,35 +84,92 @@ function N(b::BasicType{Val{:Integer}})
13084
end
13185
end
13286

133-
N(b::BasicType{Val{:Rational}}) = Rational(N(numerator(b)), N(denominator(b))) # TODO: conditionally wrap rational_get_mpq from cwrapper.h
134-
N(b::BasicType{Val{:RealDouble}}) = convert(Cdouble, b)
135-
N(b::BasicType{Val{:RealMPFR}}) = convert(BigFloat, b)
136-
N(b::BasicType{Val{:NaN}}) = NaN
137-
function N(b::BasicType{Val{:Infty}})
138-
b == oo && return Inf
139-
b == -oo && return -Inf
140-
b == zoo && return Complex(Inf, Inf)
87+
# TODO: conditionally wrap rational_get_mpq from cwrapper.h
88+
N(::Val{:Rational}, b::Basic) = Rational(N(numerator(b)), N(denominator(b)))
89+
N(::Val{:RealDouble}, b::Basic) = _convert(Cdouble, b)
90+
N(::Val{:RealMPFR}, b::Basic) = _convert(BigFloat, b)
91+
N(::Val{:Complex}, b::Basic) = complex(N(real(b)), N(imag(b)))
92+
N(::Val{:ComplexMPC}, b::Basic) = complex(N(real(b)), N(imag(b)))
93+
N(::Val{:ComplexDouble}, b::Basic) = complex(N(real(b)), N(imag(b)))
94+
95+
N(::Val{:NaN}, b::Basic) = NaN
96+
function N(::Val{:Infty}, b::Basic)
97+
if b == oo
98+
return Inf
99+
elseif b == zoo
100+
return Complex(Inf,Inf)
101+
elseif b == -oo
102+
return -Inf
103+
else
104+
throw(ArgumentError("Unknown infinity symbol"))
105+
end
141106
end
142107

143-
## Mapping of SymEngine Constants into julia values
144-
constant_map = Dict("pi" => π, "eulergamma" => γ, "exp(1)" => e, "catalan" => catalan,
145-
"goldenratio" => φ)
146-
147-
N(b::BasicType{Val{:Constant}}) = constant_map[toString(b)]
108+
function N(::Val{:Constant}, b::Basic)
109+
if b == PI
110+
return π
111+
elseif b == EulerGamma
112+
return γ
113+
elseif b == E
114+
return
115+
elseif b == Catalan
116+
return catalan
117+
elseif b == GoldenRatio
118+
return φ
119+
else
120+
throw(ArgumentError("Unknown constant"))
121+
end
122+
end
148123

149-
N(b::BasicComplexNumber) = complex(N(real(b)), N(imag(b)))
150-
function N(b::BasicType)
151-
b = convert(Basic, b)
152-
fs = free_symbols(b)
153-
if length(fs) > 0
124+
function N(::Val{<:Any}, b::Basic)
125+
is_constant(b) ||
154126
throw(ArgumentError("Object can have no free symbols"))
155-
end
156127
out = evalf(b)
157-
imag(out) == Basic(0.0) ? real(out) : out
128+
imag(out) == Basic(0.0) ? N(real(out)) : N(out)
158129
end
159130

131+
## deprecate N(::BasicType)
132+
N(b::BasicType{T}) where {T} = N(convert(Basic, b), T)
160133

161-
## Conversions SymEngine -> Julia
134+
## define convert(T, x) methods leveraging N() when needed
135+
function convert(::Type{Float64}, x::Basic)
136+
is_a_RealDouble(x) && return _convert(Cdouble, x)
137+
convert(Float64, N(evalf(x, 53, true)))
138+
end
139+
140+
function convert(::Type{BigFloat}, x::Basic)
141+
is_a_RealMPFR(x) && return _convert(BigFloat, x)
142+
convert(BigFloat, N(evalf(x, precision(BigFloat), true)))
143+
end
144+
145+
function convert(::Type{Complex{Float64}}, x::Basic)
146+
z = is_a_ComplexDouble(x) ? x : evalf(x, 53, false)
147+
a,b = _real(z), _imag(z)
148+
u,v = _convert(Cdouble, a), _convert(Cdouble, b)
149+
return complex(u,v)
150+
end
151+
152+
function convert(::Type{Complex{BigFloat}}, x::Basic)
153+
z = is_a_ComplexMPC(x) ? x : evalf(x, precision(BigFloat), false)
154+
a,b = _real(z), _imag(z)
155+
u,v = _convert(BigFloat, a), _convert(BigFloat, b)
156+
return complex(u,v)
157+
end
158+
159+
convert(::Type{Number}, x::Basic) = x
160+
convert(::Type{T}, x::Basic) where {T <: Real} = convert(T, N(x))
161+
convert(::Type{Complex{T}}, x::Basic) where {T <: Real} = convert(Complex{T}, N(x))
162+
163+
# Constructors no longer fall back to `convert` methods
164+
Base.Int64(x::Basic) = convert(Int64, x)
165+
Base.Int32(x::Basic) = convert(Int32, x)
166+
Base.Float32(x::Basic) = convert(Float32, x)
167+
Base.Float64(x::Basic) = convert(Float64, x)
168+
Base.BigInt(x::Basic) = convert(BigInt, x)
169+
Base.Real(x::Basic) = convert(Real, x)
170+
171+
172+
## Rational -- p/q parts
162173
function as_numer_denom(x::Basic)
163174
a, b = Basic(), Basic()
164175
ccall((:basic_as_numer_denom, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b, x)
@@ -170,6 +181,28 @@ denominator(x::SymbolicType) = as_numer_denom(x)[2]
170181
numerator(x::SymbolicType) = as_numer_denom(x)[1]
171182

172183
## Complex
184+
# b::Basic -> a::Basic
185+
function _real(b::Basic)
186+
if is_a_RealDouble(b) || is_a_RealMPFR(b) || is_a_Integer(b) || is_a_Rational(b)
187+
return b
188+
end
189+
if !(is_a_Complex(b) || is_a_ComplexDouble(b) || is_a_ComplexMPC(b))
190+
throw(ArgumentError("Not a complex number"))
191+
end
192+
a = Basic()
193+
ccall((:complex_base_real_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, b)
194+
return a
195+
end
196+
197+
function _imag(b::Basic)
198+
if !(is_a_Complex(b) || is_a_ComplexDouble(b) || is_a_ComplexMPC(b))
199+
throw(ArgumentError("Not a complex number"))
200+
end
201+
a = Basic()
202+
ccall((:complex_base_imaginary_part, libsymengine), Nothing, (Ref{Basic}, Ref{Basic}), a, b)
203+
return a
204+
end
205+
173206
real(x::Basic) = Basic(real(SymEngine.BasicType(x)))
174207
real(x::SymEngine.BasicType) = x
175208

@@ -186,22 +219,6 @@ conj(x::Basic) = Basic(conj(SymEngine.BasicType(x)))
186219
# To allow future extension, we define the fallback on `BasicType``.
187220
conj(x::BasicType) = 2 * real(x.x) - x.x
188221

189-
## define convert(T, x) methods leveraging N()
190-
convert(::Type{Float64}, x::Basic) = convert(Float64, N(evalf(x, 53, true)))
191-
convert(::Type{BigFloat}, x::Basic) = convert(BigFloat, N(evalf(x, precision(BigFloat), true)))
192-
convert(::Type{Complex{Float64}}, x::Basic) = convert(Complex{Float64}, N(evalf(x, 53, false)))
193-
convert(::Type{Complex{BigFloat}}, x::Basic) = convert(Complex{BigFloat}, N(evalf(x, precision(BigFloat), false)))
194-
convert(::Type{Number}, x::Basic) = x
195-
convert(::Type{T}, x::Basic) where {T <: Real} = convert(T, N(x))
196-
convert(::Type{Complex{T}}, x::Basic) where {T <: Real} = convert(Complex{T}, N(x))
197-
198-
# Constructors no longer fall back to `convert` methods
199-
Base.Int64(x::Basic) = convert(Int64, x)
200-
Base.Int32(x::Basic) = convert(Int32, x)
201-
Base.Float32(x::Basic) = convert(Float32, x)
202-
Base.Float64(x::Basic) = convert(Float64, x)
203-
Base.BigInt(x::Basic) = convert(BigInt, x)
204-
Base.Real(x::Basic) = convert(Real, x)
205222

206223
## For generic programming in Julia
207224
float(x::Basic) = float(N(x))
@@ -265,7 +282,9 @@ trunc(::Type{T},x::Basic, args...) where {T <: Integer} = convert(T, trunc(x,arg
265282
round(x::Basic; kwargs...) = Basic(round(N(x); kwargs...))
266283
round(::Type{T},x::Basic; kwargs...) where {T <: Integer} = convert(T, round(x; kwargs...))
267284

285+
prec(x::Basic) = prec(BasicType(x))
268286
prec(x::BasicType{Val{:RealMPFR}}) = ccall((:real_mpfr_get_prec, libsymengine), Clong, (Ref{Basic},), x)
287+
prec(::BasicType) = throw(ArgumentError("Method not applicable"))
269288

270289
# eps
271290
eps(x::Basic) = eps(BasicType(x))
@@ -276,3 +295,26 @@ eps(::Type{BasicType{Val{:RealDouble}}}) = 2^-52
276295
eps(::Type{BasicType{Val{:ComplexDouble}}}) = 2^-52
277296
eps(x::BasicType{Val{:RealMPFR}}) = evalf(Basic(2), prec(x), true) ^ (-prec(x)+1)
278297
eps(x::BasicType{Val{:ComplexMPFR}}) = eps(real(x))
298+
299+
## convert from BasicType
300+
function convert(::Type{BigInt}, b::BasicType{Val{:Integer}})
301+
_convert(BigInt, Basic(b))
302+
end
303+
304+
function convert(::Type{BigFloat}, b::BasicType{Val{:RealMPFR}})
305+
_convert(BigInt, Basic(b))
306+
end
307+
308+
function convert(::Type{Cdouble}, b::BasicType{Val{:RealDouble}})
309+
_convert(Cdouble, Basic(b))
310+
end
311+
312+
## real/imag for BasicType
313+
function real(b::BasicComplexNumber)
314+
_real(Basic(b))
315+
end
316+
317+
function imag(b::BasicComplexNumber)
318+
_imag(Basic(b))
319+
end
320+
## end deprecate

src/types.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,13 @@ function _get_symengine_classes()
122122
end
123123

124124
const symengine_classes = _get_symengine_classes()
125+
const symengine_classes_val = [Val(c) for c in SymEngine.symengine_classes]
126+
const symengine_classes_val_type = [Val{c} for c in SymEngine.symengine_classes]
125127

126128
"Get SymEngine class of an object (e.g. 1=>:Integer, 1//2 =:Rational, sin(x) => :Sin, ..."
127129
get_symengine_class(s::Basic) = symengine_classes[get_type(s) + 1]
130+
get_symengine_class_val(s::Basic) = symengine_classes_val[get_type(s) + 1]
131+
get_symengine_class_val_type(s::Basic) = symengine_classes_val_type[get_type(s) + 1]
128132

129133

130134
## Construct symbolic objects
@@ -221,8 +225,9 @@ SymbolicType = Union{Basic, BasicType}
221225
convert(::Type{Basic}, x::BasicType) = x.x
222226
Basic(x::BasicType) = x.x
223227

224-
BasicType(val::Basic) = BasicType{Val{get_symengine_class(val)}}(val)
225-
convert(::Type{BasicType{T}}, val::Basic) where {T} = BasicType{Val{get_symengine_class(val)}}(val)
228+
BasicType(val::Basic) = BasicType{get_symengine_class_val_type(val)}(val)
229+
convert(::Type{BasicType{T}}, val::Basic) where {T} =
230+
BasicType{get_symengine_class_val_type(val)}(val)
226231
# Needed for julia v0.4.7
227232
convert(::Type{T}, x::Basic) where {T<:BasicType} = BasicType(x)
228233

@@ -263,6 +268,14 @@ BasicTrigFunction = Union{[SymEngine.BasicType{Val{i}} for i in trig_types]...}
263268

264269

265270
###
271+
272+
"Is expression constant"
273+
function is_constant(ex::Basic)
274+
syms = CSetBasic()
275+
ccall((:basic_free_symbols, libsymengine), Nothing, (Ref{Basic}, Ptr{Cvoid}), ex, syms.ptr)
276+
Base.length(syms) == 0
277+
end
278+
266279
"Is expression a symbol"
267280
function is_symbol(x::SymbolicType)
268281
res = ccall((:is_a_Symbol, libsymengine), Cuint, (Ref{Basic},), x)

0 commit comments

Comments
 (0)