Skip to content

Commit 176ca81

Browse files
committed
make lrate optional
1 parent cd7ddf5 commit 176ca81

File tree

3 files changed

+65
-41
lines changed

3 files changed

+65
-41
lines changed

src/aggregators/coevolve.jl

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <:
1919
lrates::F1 # vector of rate lower bound functions
2020
urates::F1 # vector of rate upper bound functions
2121
rateintervals::F1 # vector of interval length functions
22+
haslratevec::Vector{Bool} # vector of whether an lrate was provided for this vrj
2223
end
2324

2425
function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing,
2526
maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool},
2627
rng::RNG; u::U, dep_graph = nothing, lrates, urates,
27-
rateintervals) where {T, S, F1, F2, RNG, U}
28+
rateintervals, haslratevec) where {T, S, F1, F2, RNG, U}
2829
if dep_graph === nothing
2930
if (get_num_majumps(maj) == 0) || !isempty(rs)
3031
error("To use Coevolve a dependency graph between jumps must be supplied.")
@@ -47,7 +48,7 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not
4748
pq = MutableBinaryMinHeap{T}()
4849
CoevolveJumpAggregation{T, S, F1, F2, RNG, typeof(dg),
4950
typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng,
50-
dg, pq, lrates, urates, rateintervals)
51+
dg, pq, lrates, urates, rateintervals, haslratevec)
5152
end
5253

5354
# creating the JumpAggregation structure (tuple-based variable jumps)
@@ -57,37 +58,46 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
5758
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{Any}}
5859
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
5960
Tuple{typeof(u), typeof(p), typeof(t)}}
60-
affects! = Vector{AffectWrapper}()
61-
rates = Vector{RateWrapper}()
62-
lrates = Vector{RateWrapper}()
63-
urates = Vector{RateWrapper}()
64-
rateintervals = Vector{RateWrapper}()
65-
66-
if (constant_jumps !== nothing) && !isempty(constant_jumps)
67-
append!(affects!,
68-
[AffectWrapper((integrator) -> (j.affect!(integrator); nothing))
69-
for j in constant_jumps])
70-
append!(urates, [RateWrapper(j.rate) for j in constant_jumps])
61+
62+
ncrjs = (constant_jumps === nothing) ? 0 : length(constant_jumps)
63+
nvrjs = (variable_jumps === nothing) ? 0 : length(variable_jumps)
64+
nrjs = ncrjs + nvrjs
65+
affects! = Vector{AffectWrapper}(undef, nrjs)
66+
rates = Vector{RateWrapper}(undef, nvrjs)
67+
lrates = similar(rates)
68+
urates = similar(rates)
69+
rateintervals = similar(rates)
70+
haslratevec = zeros(Bool, nvrjs)
71+
72+
idx = 1
73+
if constant_jumps !== nothing
74+
for crj in constant_jumps
75+
affects![idx] = AffectWrapper(integ -> (crj.affect!(integ); nothing))
76+
urates[idx] = RateWrapper(crj.rate)
77+
idx += 1
78+
end
7179
end
7280

73-
if (variable_jumps !== nothing) && !isempty(variable_jumps)
74-
append!(affects!,
75-
[AffectWrapper((integrator) -> (j.affect!(integrator); nothing))
76-
for j in variable_jumps])
77-
append!(rates, [RateWrapper(j.rate) for j in variable_jumps])
78-
append!(lrates, [RateWrapper(j.lrate) for j in variable_jumps])
79-
append!(urates, [RateWrapper(j.urate) for j in variable_jumps])
80-
append!(rateintervals, [RateWrapper(j.rateinterval) for j in variable_jumps])
81+
if variable_jumps !== nothing
82+
for (i, vrj) in enumerate(variable_jumps)
83+
affects![idx] = AffectWrapper(integ -> (vrj.affect!(integ); nothing))
84+
urates[idx] = RateWrapper(vrj.urate)
85+
idx += 1
86+
rates[i] = RateWrapper(vrj.rate)
87+
rateintervals[i] = RateWrapper(vrj.rateinterval)
88+
haslratevec[i] = haslrate(vrj)
89+
lrates[i] = haslratevec[i] ? RateWrapper(vrj.lrate) : RateWrapper(nullrate)
90+
end
8191
end
8292

83-
num_jumps = get_num_majumps(ma_jumps) + length(urates)
93+
num_jumps = get_num_majumps(ma_jumps) + nrjs
8494
cur_rates = Vector{typeof(t)}(undef, num_jumps)
8595
sum_rate = nothing
8696
next_jump = 0
8797
next_jump_time = typemax(t)
8898
CoevolveJumpAggregation(next_jump, next_jump_time, end_time, cur_rates, sum_rate,
8999
ma_jumps, rates, affects!, save_positions, rng;
90-
u, dep_graph, lrates, urates, rateintervals)
100+
u, dep_graph, lrates, urates, rateintervals, haslratevec)
91101
end
92102

93103
# set up a new simulation and calculate the first jump / jump time
@@ -146,7 +156,7 @@ end
146156
end
147157

148158
function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) where {T}
149-
@unpack rng = p
159+
@unpack rng, haslratevec = p
150160
num_majumps = get_num_majumps(p.ma_jumps)
151161
num_cjumps = length(p.urates) - length(p.rates)
152162
uidx = i - num_majumps
@@ -171,7 +181,7 @@ function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) whe
171181
end
172182
(_t >= tstop) && break
173183

174-
lrate = get_lrate(p, lidx, u, params, t)
184+
lrate = haslratevec[lidx] ? get_lrate(p, lidx, u, params, t) : zero(t)
175185
if lrate < urate
176186
# when the lower and upper bound are the same, then v < 1 = lrate / urate = urate / urate
177187
v = rand(rng) * urate

src/jumps.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ struct VariableRateJump{R, F, R2, R3, R4, I, T, T2} <: AbstractJump
155155
reltol::T2
156156
end
157157

158+
isbounded(::VariableRateJump) = true
159+
isbounded(::VariableRateJump{R,F,R2,Nothing}) where {R,F,R2} = false
160+
haslrate(::VariableRateJump) = true
161+
haslrate(::VariableRateJump{R, F, Nothing}) where {R,F} = false
158162
nullrate(u, p, t::T) where {T <: Number} = zero(T)
159163

160164
"""
@@ -179,8 +183,8 @@ function VariableRateJump(rate, affect!;
179183
error("`urate` and `rateinterval` must both be `nothing`, or must both be defined.")
180184
end
181185

182-
if (urate !== nothing && lrate === nothing)
183-
lrate = nullrate
186+
if lrate !== nothing
187+
(urate !== nothing) || error("If a lower bound rate, `lrate`, is given than an upper bound rate, `urate`, and rate interval, `rateinterval`, must also be provided.")
184188
end
185189

186190
VariableRateJump(rate, affect!, lrate, urate, rateinterval, idxs, rootfind,

test/hawkes_test.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,22 @@ function hawkes_rate(i::Int, g, h)
3232
return rate
3333
end
3434

35-
function hawkes_jump(i::Int, g, h)
35+
function hawkes_jump(i::Int, g, h; uselrate = true)
3636
rate = hawkes_rate(i, g, h)
37-
lrate(u, p, t) = p[1]
3837
urate = rate
39-
function rateinterval(u, p, t)
40-
_lrate = lrate(u, p, t)
41-
_urate = urate(u, p, t)
42-
return _urate == _lrate ? typemax(t) : 1 / (2 * _urate)
38+
if uselrate
39+
lrate(u, p, t) = p[1]
40+
rateinterval = (u, p, t) -> begin
41+
_lrate = lrate(u, p, t)
42+
_urate = urate(u, p, t)
43+
return _urate == _lrate ? typemax(t) : 1 / (2 * _urate)
44+
end
45+
else
46+
lrate = nothing
47+
rateinterval = (u, p, t) -> begin
48+
_urate = urate(u, p, t)
49+
return 1 / (2 * _urate)
50+
end
4351
end
4452
function affect!(integrator)
4553
push!(h[i], integrator.t)
@@ -48,15 +56,15 @@ function hawkes_jump(i::Int, g, h)
4856
return VariableRateJump(rate, affect!; lrate, urate, rateinterval)
4957
end
5058

51-
function hawkes_jump(u, g, h)
52-
return [hawkes_jump(i, g, h) for i in 1:length(u)]
59+
function hawkes_jump(u, g, h; uselrate = true)
60+
return [hawkes_jump(i, g, h; uselrate) for i in 1:length(u)]
5361
end
5462

5563
function hawkes_problem(p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0),
5664
save_positions = (false, true),
57-
g = [[1]], h = [[]])
65+
g = [[1]], h = [[]], uselrate = true)
5866
dprob = DiscreteProblem(u, tspan, p)
59-
jumps = hawkes_jump(u, g, h)
67+
jumps = hawkes_jump(u, g, h; uselrate)
6068
jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions, rng)
6169
return jprob
6270
end
@@ -68,7 +76,7 @@ end
6876

6977
function hawkes_problem(p, agg; u = [0.0], tspan = (0.0, 50.0),
7078
save_positions = (false, true),
71-
g = [[1]], h = [[]])
79+
g = [[1]], h = [[]], kwargs...)
7280
oprob = ODEProblem(f!, u, tspan, p)
7381
jumps = hawkes_jump(u, g, h)
7482
jprob = JumpProblem(oprob, agg, jumps...; save_positions, rng)
@@ -97,11 +105,13 @@ h = [Float64[]]
97105

98106
Eλ, Varλ = expected_stats_hawkes_problem(p, tspan)
99107

100-
algs = (Direct(), Coevolve())
108+
algs = (Direct(), Coevolve(), Coevolve())
109+
uselrate = zeros(Bool, length(algs))
110+
uselrate[3] = true
101111
Nsims = 250
102112

103-
for alg in algs
104-
jump_prob = hawkes_problem(p, alg; u = u0, tspan, g, h)
113+
for (i, alg) in enumerate(algs)
114+
jump_prob = hawkes_problem(p, alg; u = u0, tspan, g, h, uselrate = uselrate[i])
105115
if typeof(alg) <: Coevolve
106116
stepper = SSAStepper()
107117
else

0 commit comments

Comments
 (0)