@@ -19,12 +19,13 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <:
1919 lrates:: F1 # vector of rate lower bound functions
2020 urates:: F1 # vector of rate upper bound functions
2121 rateintervals:: F1 # vector of interval length functions
22+ haslratevec:: Vector{Bool} # vector of whether an lrate was provided for this vrj
2223end
2324
2425function CoevolveJumpAggregation (nj:: Int , njt:: T , et:: T , crs:: Vector{T} , sr:: Nothing ,
2526 maj:: S , rs:: F1 , affs!:: F2 , sps:: Tuple{Bool, Bool} ,
2627 rng:: RNG ; u:: U , dep_graph = nothing , lrates, urates,
27- rateintervals) where {T, S, F1, F2, RNG, U}
28+ rateintervals, haslratevec ) where {T, S, F1, F2, RNG, U}
2829 if dep_graph === nothing
2930 if (get_num_majumps (maj) == 0 ) || ! isempty (rs)
3031 error (" To use Coevolve a dependency graph between jumps must be supplied." )
@@ -47,7 +48,7 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not
4748 pq = MutableBinaryMinHeap {T} ()
4849 CoevolveJumpAggregation{T, S, F1, F2, RNG, typeof (dg),
4950 typeof (pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng,
50- dg, pq, lrates, urates, rateintervals)
51+ dg, pq, lrates, urates, rateintervals, haslratevec )
5152end
5253
5354# creating the JumpAggregation structure (tuple-based variable jumps)
@@ -57,37 +58,46 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
5758 AffectWrapper = FunctionWrappers. FunctionWrapper{Nothing, Tuple{Any}}
5859 RateWrapper = FunctionWrappers. FunctionWrapper{typeof (t),
5960 Tuple{typeof (u), typeof (p), typeof (t)}}
60- affects! = Vector {AffectWrapper} ()
61- rates = Vector {RateWrapper} ()
62- lrates = Vector {RateWrapper} ()
63- urates = Vector {RateWrapper} ()
64- rateintervals = Vector {RateWrapper} ()
65-
66- if (constant_jumps != = nothing ) && ! isempty (constant_jumps)
67- append! (affects!,
68- [AffectWrapper ((integrator) -> (j. affect! (integrator); nothing ))
69- for j in constant_jumps])
70- append! (urates, [RateWrapper (j. rate) for j in constant_jumps])
61+
62+ ncrjs = (constant_jumps === nothing ) ? 0 : length (constant_jumps)
63+ nvrjs = (variable_jumps === nothing ) ? 0 : length (variable_jumps)
64+ nrjs = ncrjs + nvrjs
65+ affects! = Vector {AffectWrapper} (undef, nrjs)
66+ rates = Vector {RateWrapper} (undef, nvrjs)
67+ lrates = similar (rates)
68+ urates = similar (rates)
69+ rateintervals = similar (rates)
70+ haslratevec = zeros (Bool, nvrjs)
71+
72+ idx = 1
73+ if constant_jumps != = nothing
74+ for crj in constant_jumps
75+ affects![idx] = AffectWrapper (integ -> (crj. affect! (integ); nothing ))
76+ urates[idx] = RateWrapper (crj. rate)
77+ idx += 1
78+ end
7179 end
7280
73- if (variable_jumps != = nothing ) && ! isempty (variable_jumps)
74- append! (affects!,
75- [AffectWrapper ((integrator) -> (j. affect! (integrator); nothing ))
76- for j in variable_jumps])
77- append! (rates, [RateWrapper (j. rate) for j in variable_jumps])
78- append! (lrates, [RateWrapper (j. lrate) for j in variable_jumps])
79- append! (urates, [RateWrapper (j. urate) for j in variable_jumps])
80- append! (rateintervals, [RateWrapper (j. rateinterval) for j in variable_jumps])
81+ if variable_jumps != = nothing
82+ for (i, vrj) in enumerate (variable_jumps)
83+ affects![idx] = AffectWrapper (integ -> (vrj. affect! (integ); nothing ))
84+ urates[idx] = RateWrapper (vrj. urate)
85+ idx += 1
86+ rates[i] = RateWrapper (vrj. rate)
87+ rateintervals[i] = RateWrapper (vrj. rateinterval)
88+ haslratevec[i] = haslrate (vrj)
89+ lrates[i] = haslratevec[i] ? RateWrapper (vrj. lrate) : RateWrapper (nullrate)
90+ end
8191 end
8292
83- num_jumps = get_num_majumps (ma_jumps) + length (urates)
93+ num_jumps = get_num_majumps (ma_jumps) + nrjs
8494 cur_rates = Vector {typeof(t)} (undef, num_jumps)
8595 sum_rate = nothing
8696 next_jump = 0
8797 next_jump_time = typemax (t)
8898 CoevolveJumpAggregation (next_jump, next_jump_time, end_time, cur_rates, sum_rate,
8999 ma_jumps, rates, affects!, save_positions, rng;
90- u, dep_graph, lrates, urates, rateintervals)
100+ u, dep_graph, lrates, urates, rateintervals, haslratevec )
91101end
92102
93103# set up a new simulation and calculate the first jump / jump time
146156end
147157
148158function next_time (p:: CoevolveJumpAggregation{T} , u, params, t, i, tstop:: T ) where {T}
149- @unpack rng = p
159+ @unpack rng, haslratevec = p
150160 num_majumps = get_num_majumps (p. ma_jumps)
151161 num_cjumps = length (p. urates) - length (p. rates)
152162 uidx = i - num_majumps
@@ -171,7 +181,7 @@ function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) whe
171181 end
172182 (_t >= tstop) && break
173183
174- lrate = get_lrate (p, lidx, u, params, t)
184+ lrate = haslratevec[lidx] ? get_lrate (p, lidx, u, params, t) : zero ( t)
175185 if lrate < urate
176186 # when the lower and upper bound are the same, then v < 1 = lrate / urate = urate / urate
177187 v = rand (rng) * urate
0 commit comments