Skip to content

Commit d5cacb3

Browse files
committed
rework JumpProblem
1 parent cf0f050 commit d5cacb3

File tree

2 files changed

+67
-49
lines changed

2 files changed

+67
-49
lines changed

src/jumps.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,27 @@ function JumpSet(vjs, cjs, rj, majv::Vector{T}) where {T <: MassActionJump}
527527
end
528528

529529
@inline get_num_majumps(jset::JumpSet) = get_num_majumps(jset.massaction_jump)
530+
@inline num_majumps(jset::JumpSet) = get_num_majumps(jset)
531+
532+
@inline function num_crjs(jset::JumpSet)
533+
(jset.constant_jumps !== nothing) ? length(jset.constant_jumps) : 0
534+
end
535+
536+
@inline function num_vrjs(jset::JumpSet)
537+
(jset.variable_jumps !== nothing) ? length(jset.variable_jumps) : 0
538+
end
539+
540+
@inline function num_bndvrjs(jset::JumpSet)
541+
(jset.variable_jumps !== nothing) ? count(isbounded, jset.variable_jumps) : 0
542+
end
543+
544+
@inline function num_continvrjs(jset::JumpSet)
545+
(jset.variable_jumps !== nothing) ? count(!isbounded, jset.variable_jumps) : 0
546+
end
547+
548+
num_jumps(jset::JumpSet) = num_majumps(jset) + num_crjs(jset) + num_vrjs(jset)
549+
num_discretejumps(jset::JumpSet) = num_majumps(jset) + num_crjs(jset) + num_bndvrjs(jset)
550+
num_cdiscretejumps(jset::JumpSet) = num_majumps(jset) + num_crjs(jset)
530551

531552
@inline split_jumps(vj, cj, rj, maj) = vj, cj, rj, maj
532553
@inline function split_jumps(vj, cj, rj, maj, v::VariableRateJump, args...)

src/problem.jl

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ $(FIELDS)
5050
the jump occurs.
5151
- `spatial_system`, for spatial problems the underlying spatial structure.
5252
- `hopping_constants`, for spatial problems the spatial transition rate coefficients.
53+
- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s
54+
with a supporting aggregator (such as `Coevolve`). They will then be handled via the
55+
continuous integration interace, and treated like general `VariableRateJump`s.
5356
5457
Please see the [tutorial
5558
page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the
@@ -166,7 +169,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
166169
(false, true) : (true, true),
167170
rng = DEFAULT_RNG, scale_rates = true, useiszero = true,
168171
spatial_system = nothing, hopping_constants = nothing,
169-
callback = nothing, kwargs...)
172+
callback = nothing, use_vrj_bounds = true, kwargs...)
170173

171174
# initialize the MassActionJump rate constants with the user parameters
172175
if using_params(jumps.massaction_jump)
@@ -182,61 +185,55 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
182185

183186
## Spatial jumps handling
184187
if spatial_system !== nothing && hopping_constants !== nothing &&
185-
!is_spatial(aggregator) # check if need to flatten
188+
!is_spatial(aggregator)
186189
prob, maj = flatten(maj, prob, spatial_system, hopping_constants; kwargs...)
187190
end
188191

189-
## Constant and variable rate handling
192+
if is_spatial(aggregator)
193+
(num_crjs(jumps) == num_vrjs(jumps) == 0) || error("Spatial aggregators only support MassActionJumps currently.")
194+
kwargs = merge((; hopping_constants, spatial_system), kwargs)
195+
end
196+
197+
ndiscjumps = get_num_majumps(maj) + num_crjs(jumps)
198+
199+
# separate bounded variable rate jumps *if* the aggregator can use them
200+
if use_vrj_bounds && supports_variablerates(aggregator) && (num_bndvrjs(jumps) > 0)
201+
bvrjs = filter(isbounded, jumps.variable_jumps)
202+
cvrjs = filter(!isbounded, jumps.variable_jumps)
203+
kwargs = merge((; variable_jumps = bvrjs), kwargs)
204+
ndiscjumps += length(bvrjs)
205+
else
206+
bvrjs = nothing
207+
cvrjs = jumps.variable_jumps
208+
end
209+
190210
t, end_time, u = prob.tspan[1], prob.tspan[2], prob.u0
191211

192-
if length(jumps.variable_jumps) == 0 && (length(jumps.constant_jumps) == 0) &&
193-
(maj === nothing) && !is_spatial(aggregator)
194-
# check if there are no jumps
195-
new_prob = prob
196-
variable_jump_callback = CallbackSet()
197-
cont_agg = JumpSet().variable_jumps
212+
# handle majs, crjs, and bounded vrjs
213+
if (ndiscjumps == 0) && !is_spatial(aggregator)
198214
disc_agg = nothing
199215
constant_jump_callback = CallbackSet()
200-
elseif supports_variablerates(aggregator)
201-
new_prob = prob
216+
else
202217
disc_agg = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps, maj,
203-
save_positions, rng; variable_jumps = jumps.variable_jumps,
204-
kwargs...)
218+
save_positions, rng; kwargs...)
205219
constant_jump_callback = DiscreteCallback(disc_agg)
220+
end
221+
222+
# handle any remaining vrjs
223+
if length(cvrjs) > 0
224+
new_prob = extend_problem(prob, cvrjs; rng)
225+
variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng)
226+
cont_agg = cvrjs
227+
else
228+
new_prob = prob
206229
variable_jump_callback = CallbackSet()
207230
cont_agg = JumpSet().variable_jumps
208-
else
209-
# the fallback is to handle each jump type separately
210-
if (length(jumps.constant_jumps) == 0) && (maj === nothing) &&
211-
!is_spatial(aggregator)
212-
disc_agg = nothing
213-
constant_jump_callback = CallbackSet()
214-
else
215-
disc_agg = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps,
216-
maj,
217-
save_positions, rng; spatial_system = spatial_system,
218-
hopping_constants = hopping_constants, kwargs...)
219-
constant_jump_callback = DiscreteCallback(disc_agg)
220-
end
221-
222-
if length(jumps.variable_jumps) > 0 && !is_spatial(aggregator)
223-
new_prob = extend_problem(prob, jumps; rng = rng)
224-
variable_jump_callback = build_variable_callback(CallbackSet(), 0,
225-
jumps.variable_jumps...;
226-
rng = rng)
227-
cont_agg = jumps.variable_jumps
228-
else
229-
new_prob = prob
230-
variable_jump_callback = CallbackSet()
231-
cont_agg = JumpSet().variable_jumps
232-
end
233231
end
234232

235233
jump_cbs = CallbackSet(constant_jump_callback, variable_jump_callback)
236-
237234
iip = isinplace_jump(prob, jumps.regular_jump)
238-
239235
solkwargs = make_kwarg(; callback)
236+
240237
JumpProblem{iip, typeof(new_prob), typeof(aggregator),
241238
typeof(jump_cbs), typeof(disc_agg),
242239
typeof(cont_agg),
@@ -248,7 +245,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
248245
end
249246

250247
function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG)
251-
error("VariableRateJumps require a continuous problem, like an ODE/SDE/DDE/DAE problem.")
248+
error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.")
252249
end
253250

254251
function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG)
@@ -257,19 +254,19 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL
257254
jump_f = let _f = _f
258255
function jump_f(du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
259256
_f(du.u, u.u, p, t)
260-
update_jumps!(du, u, p, t, length(u.u), jumps.variable_jumps...)
257+
update_jumps!(du, u, p, t, length(u.u), jumps...)
261258
end
262259
end
263260
ttype = eltype(prob.tspan)
264261
u0 = ExtendedJumpArray(prob.u0,
265-
[-randexp(rng, ttype) for i in 1:length(jumps.variable_jumps)])
262+
[-randexp(rng, ttype) for i in 1:length(jumps)])
266263
remake(prob, f = ODEFunction{true}(jump_f), u0 = u0)
267264
end
268265

269266
function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG)
270267
function jump_f(du, u, p, t)
271268
prob.f(du.u, u.u, p, t)
272-
update_jumps!(du, u, p, t, length(u.u), jumps.variable_jumps...)
269+
update_jumps!(du, u, p, t, length(u.u), jumps...)
273270
end
274271

275272
if prob.noise_rate_prototype === nothing
@@ -284,30 +281,30 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL
284281

285282
ttype = eltype(prob.tspan)
286283
u0 = ExtendedJumpArray(prob.u0,
287-
[-randexp(rng, ttype) for i in 1:length(jumps.variable_jumps)])
284+
[-randexp(rng, ttype) for i in 1:length(jumps)])
288285
remake(prob, f = SDEFunction{true}(jump_f, jump_g), g = jump_g, u0 = u0)
289286
end
290287

291288
function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG)
292289
jump_f = function (du, u, h, p, t)
293290
prob.f(du.u, u.u, h, p, t)
294-
update_jumps!(du, u, p, t, length(u.u), jumps.variable_jumps...)
291+
update_jumps!(du, u, p, t, length(u.u), jumps...)
295292
end
296293
ttype = eltype(prob.tspan)
297294
u0 = ExtendedJumpArray(prob.u0,
298-
[-randexp(rng, ttype) for i in 1:length(jumps.variable_jumps)])
295+
[-randexp(rng, ttype) for i in 1:length(jumps)])
299296
remake(prob, f = DDEFunction{true}(jump_f), u0 = u0)
300297
end
301298

302299
# Not sure if the DAE one is correct: Should be a residual of sorts
303300
function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG)
304301
jump_f = function (out, du, u, p, t)
305302
prob.f(out.u, du.u, u.u, t)
306-
update_jumps!(du, u, t, length(u.u), jumps.variable_jumps...)
303+
update_jumps!(du, u, t, length(u.u), jumps...)
307304
end
308305
ttype = eltype(prob.tspan)
309306
u0 = ExtendedJumpArray(prob.u0,
310-
[-randexp(rng, ttype) for i in 1:length(jumps.variable_jumps)])
307+
[-randexp(rng, ttype) for i in 1:length(jumps)])
311308
remake(prob, f = DAEFunction{true}(jump_f), u0 = u0)
312309
end
313310

0 commit comments

Comments
 (0)