Skip to content

Commit e271864

Browse files
committed
L -> rateinterval
1 parent ca46720 commit e271864

File tree

3 files changed

+26
-37
lines changed

3 files changed

+26
-37
lines changed

docs/src/tutorials/discrete_stochastic_example.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,28 +510,29 @@ define an infection reaction as a bounded `VariableRateJump`, requiring us to
510510
again provide `rate` and `affect` functions, but also give functions that
511511
calculate an upper-bound on the rate (`urate(u,p,t)`), an optional lower-bound
512512
on the rate (`lrate(u,p,t)`), and a time window over which the bounds are valid
513-
as long as any states these three rates depend on are unchanged (`L(u,p,t)`).
513+
as long as any states these three rates depend on are unchanged
514+
(`rateinterval(u,p,t)`).
514515
The lower- and upper-bounds of the rate should be valid from the time they are
515-
computed `t` until `t + L(u, p, t)`:
516+
computed `t` until `t + rateinterval(u, p, t)`:
516517
517518
```@example tut2
518519
H = zeros(Float64, 10)
519520
rate3(u, p, t) = p[1]*u[1]*u[2] + p[3]*u[1]*sum(exp(-p[4]*(t - _t)) for _t in H)
520521
lrate = rate1 # β*S*I
521522
urate = rate3
522-
L(u, p, t) = 1 / (2*urate(u, p, t))
523+
rateinterval(u, p, t) = 1 / (2*urate(u, p, t))
523524
function affect3!(integrator)
524525
integrator.u[1] -= 1 # S -> S - 1
525526
integrator.u[2] += 1 # I -> I + 1
526527
push!(H, integrator.t)
527528
nothing
528529
end
529-
jump3 = VariableRateJump(rate3, affect3!; lrate=lrate, urate=urate, L=L)
530+
jump3 = VariableRateJump(rate3, affect3!; lrate, urate, rateinterval)
530531
```
531532
Note that here we set the lower bound rate to be the normal SIR infection rate,
532533
and set the upper bound rate equal to the new rate of infection (`rate3`). As
533534
long as `u[1]` and `u[2]` are unchanged by another jump, for any `s` in `[t,t +
534-
L(u,p,t)]` we have that `lrate(u,p,t) <= rate3(u,p,s) <= urate(u,p,t)`.
535+
rateinterval(u,p,t)]` we have that `lrate(u,p,t) <= rate3(u,p,s) <= urate(u,p,t)`.
535536
536537
Next, we redefine the recovery jump's `affect!` such that a random infection is
537538
removed from `H` for every recovery.

src/aggregators/coevolve.jl

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <:
1818
pq::PQ # priority queue of next time
1919
lrates::F1 # vector of rate lower bound functions
2020
urates::F1 # vector of rate upper bound functions
21-
Ls::F1 # vector of interval length functions
21+
rateintervals::F1 # vector of interval length functions
2222
end
2323

2424
function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing,
2525
maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool},
26-
rng::RNG; u::U,
27-
dep_graph = nothing,
28-
lrates, urates, Ls) where {T, S, F1, F2, RNG, U}
26+
rng::RNG; u::U, dep_graph = nothing, lrates, urates,
27+
rateintervals) where {T, S, F1, F2, RNG, U}
2928
if dep_graph === nothing
3029
if (get_num_majumps(maj) == 0) || !isempty(rs)
3130
error("To use Coevolve a dependency graph between jumps must be supplied.")
@@ -47,30 +46,22 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not
4746

4847
pq = MutableBinaryMinHeap{T}()
4948
CoevolveJumpAggregation{T, S, F1, F2, RNG, typeof(dg),
50-
typeof(pq)
51-
}(nj, nj, njt,
52-
et,
53-
crs, sr, maj,
54-
rs,
55-
affs!, sps,
56-
rng,
57-
dg, pq,
58-
lrates, urates, Ls)
49+
typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng,
50+
dg, pq, lrates, urates, rateintervals)
5951
end
6052

6153
# creating the JumpAggregation structure (tuple-based variable jumps)
6254
function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
63-
ma_jumps, save_positions, rng;
64-
dep_graph = nothing, variable_jumps = nothing,
65-
kwargs...)
55+
ma_jumps, save_positions, rng; dep_graph = nothing,
56+
variable_jumps = nothing, kwargs...)
6657
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{Any}}
6758
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
6859
Tuple{typeof(u), typeof(p), typeof(t)}}
6960
affects! = Vector{AffectWrapper}()
7061
rates = Vector{RateWrapper}()
7162
lrates = Vector{RateWrapper}()
7263
urates = Vector{RateWrapper}()
73-
Ls = Vector{RateWrapper}()
64+
rateintervals = Vector{RateWrapper}()
7465

7566
if (constant_jumps !== nothing) && !isempty(constant_jumps)
7667
append!(affects!,
@@ -86,7 +77,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
8677
append!(rates, [RateWrapper(j.rate) for j in variable_jumps])
8778
append!(lrates, [RateWrapper(j.lrate) for j in variable_jumps])
8879
append!(urates, [RateWrapper(j.urate) for j in variable_jumps])
89-
append!(Ls, [RateWrapper(j.L) for j in variable_jumps])
80+
append!(rateintervals, [RateWrapper(j.rateinterval) for j in variable_jumps])
9081
end
9182

9283
num_jumps = get_num_majumps(ma_jumps) + length(urates)
@@ -96,9 +87,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
9687
next_jump_time = typemax(t)
9788
CoevolveJumpAggregation(next_jump, next_jump_time, end_time, cur_rates, sum_rate,
9889
ma_jumps, rates, affects!, save_positions, rng;
99-
u = u,
100-
dep_graph = dep_graph,
101-
lrates = lrates, urates = urates, Ls = Ls)
90+
u, dep_graph, lrates, urates, rateintervals)
10291
end
10392

10493
# set up a new simulation and calculate the first jump / jump time
@@ -144,8 +133,8 @@ end
144133
@inbounds return p.urates[uidx](u, params, t)
145134
end
146135

147-
@inline function get_L(p::CoevolveJumpAggregation, lidx, u, params, t)
148-
@inbounds return p.Ls[lidx](u, params, t)
136+
@inline function get_rateinterval(p::CoevolveJumpAggregation, lidx, u, params, t)
137+
@inbounds return p.rateintervals[lidx](u, params, t)
149138
end
150139

151140
@inline function get_lrate(p::CoevolveJumpAggregation, lidx, u, params, t)
@@ -173,9 +162,9 @@ function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) whe
173162
_t = t + s
174163
if lidx > 0
175164
while t < tstop
176-
L = get_L(p, lidx, u, params, t)
177-
if s > L
178-
t = t + L
165+
rateinterval = get_rateinterval(p, lidx, u, params, t)
166+
if s > rateinterval
167+
t = t + rateinterval
179168
urate = get_urate(p, uidx, u, params, t)
180169
s = urate == zero(t) ? typemax(t) : randexp(rng) / urate
181170
_t = t + s

test/hawkes_test.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function hawkes_jump(i::Int, g, h)
3636
rate = hawkes_rate(i, g, h)
3737
lrate(u, p, t) = p[1]
3838
urate = rate
39-
function L(u, p, t)
39+
function rateinterval(u, p, t)
4040
_lrate = lrate(u, p, t)
4141
_urate = urate(u, p, t)
4242
return _urate == _lrate ? typemax(t) : 1 / (2 * _urate)
@@ -45,7 +45,7 @@ function hawkes_jump(i::Int, g, h)
4545
push!(h[i], integrator.t)
4646
integrator.u[i] += 1
4747
end
48-
return VariableRateJump(rate, affect!; lrate = lrate, urate = urate, L = L)
48+
return VariableRateJump(rate, affect!; lrate, urate, rateinterval)
4949
end
5050

5151
function hawkes_jump(u, g, h)
@@ -57,8 +57,7 @@ function hawkes_problem(p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0),
5757
g = [[1]], h = [[]])
5858
dprob = DiscreteProblem(u, tspan, p)
5959
jumps = hawkes_jump(u, g, h)
60-
jprob = JumpProblem(dprob, agg, jumps...;
61-
dep_graph = g, save_positions = save_positions, rng = rng)
60+
jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions, rng)
6261
return jprob
6362
end
6463

@@ -72,7 +71,7 @@ function hawkes_problem(p, agg; u = [0.0], tspan = (0.0, 50.0),
7271
g = [[1]], h = [[]])
7372
oprob = ODEProblem(f!, u, tspan, p)
7473
jumps = hawkes_jump(u, g, h)
75-
jprob = JumpProblem(oprob, agg, jumps...; save_positions = save_positions, rng = rng)
74+
jprob = JumpProblem(oprob, agg, jumps...; save_positions, rng)
7675
return jprob
7776
end
7877

@@ -102,7 +101,7 @@ algs = (Direct(), Coevolve())
102101
Nsims = 250
103102

104103
for alg in algs
105-
jump_prob = hawkes_problem(p, alg; u = u0, tspan = tspan, g = g, h = h)
104+
jump_prob = hawkes_problem(p, alg; u = u0, tspan, g, h)
106105
if typeof(alg) <: Coevolve
107106
stepper = SSAStepper()
108107
else

0 commit comments

Comments
 (0)