@@ -50,6 +50,9 @@ $(FIELDS)
5050 the jump occurs.
5151- `spatial_system`, for spatial problems the underlying spatial structure.
5252- `hopping_constants`, for spatial problems the spatial transition rate coefficients.
53+ - `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s
54+ with a supporting aggregator (such as `Coevolve`). They will then be handled via the
55+ continuous integration interace, and treated like general `VariableRateJump`s.
5356
5457Please see the [tutorial
5558page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the
@@ -166,7 +169,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
166169 (false , true ) : (true , true ),
167170 rng = DEFAULT_RNG, scale_rates = true , useiszero = true ,
168171 spatial_system = nothing , hopping_constants = nothing ,
169- callback = nothing , kwargs... )
172+ callback = nothing , use_vrj_bounds = true , kwargs... )
170173
171174 # initialize the MassActionJump rate constants with the user parameters
172175 if using_params (jumps. massaction_jump)
@@ -182,61 +185,55 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
182185
183186 # # Spatial jumps handling
184187 if spatial_system != = nothing && hopping_constants != = nothing &&
185- ! is_spatial (aggregator) # check if need to flatten
188+ ! is_spatial (aggregator)
186189 prob, maj = flatten (maj, prob, spatial_system, hopping_constants; kwargs... )
187190 end
188191
189- # # Constant and variable rate handling
192+ if is_spatial (aggregator)
193+ (num_crjs (jumps) == num_vrjs (jumps) == 0 ) || error (" Spatial aggregators only support MassActionJumps currently." )
194+ kwargs = merge ((; hopping_constants, spatial_system), kwargs)
195+ end
196+
197+ ndiscjumps = get_num_majumps (maj) + num_crjs (jumps)
198+
199+ # separate bounded variable rate jumps *if* the aggregator can use them
200+ if use_vrj_bounds && supports_variablerates (aggregator) && (num_bndvrjs (jumps) > 0 )
201+ bvrjs = filter (isbounded, jumps. variable_jumps)
202+ cvrjs = filter (! isbounded, jumps. variable_jumps)
203+ kwargs = merge ((; variable_jumps = bvrjs), kwargs)
204+ ndiscjumps += length (bvrjs)
205+ else
206+ bvrjs = nothing
207+ cvrjs = jumps. variable_jumps
208+ end
209+
190210 t, end_time, u = prob. tspan[1 ], prob. tspan[2 ], prob. u0
191211
192- if length (jumps. variable_jumps) == 0 && (length (jumps. constant_jumps) == 0 ) &&
193- (maj === nothing ) && ! is_spatial (aggregator)
194- # check if there are no jumps
195- new_prob = prob
196- variable_jump_callback = CallbackSet ()
197- cont_agg = JumpSet (). variable_jumps
212+ # handle majs, crjs, and bounded vrjs
213+ if (ndiscjumps == 0 ) && ! is_spatial (aggregator)
198214 disc_agg = nothing
199215 constant_jump_callback = CallbackSet ()
200- elseif supports_variablerates (aggregator)
201- new_prob = prob
216+ else
202217 disc_agg = aggregate (aggregator, u, prob. p, t, end_time, jumps. constant_jumps, maj,
203- save_positions, rng; variable_jumps = jumps. variable_jumps,
204- kwargs... )
218+ save_positions, rng; kwargs... )
205219 constant_jump_callback = DiscreteCallback (disc_agg)
220+ end
221+
222+ # handle any remaining vrjs
223+ if length (cvrjs) > 0
224+ new_prob = extend_problem (prob, cvrjs; rng)
225+ variable_jump_callback = build_variable_callback (CallbackSet (), 0 , cvrjs... ; rng)
226+ cont_agg = cvrjs
227+ else
228+ new_prob = prob
206229 variable_jump_callback = CallbackSet ()
207230 cont_agg = JumpSet (). variable_jumps
208- else
209- # the fallback is to handle each jump type separately
210- if (length (jumps. constant_jumps) == 0 ) && (maj === nothing ) &&
211- ! is_spatial (aggregator)
212- disc_agg = nothing
213- constant_jump_callback = CallbackSet ()
214- else
215- disc_agg = aggregate (aggregator, u, prob. p, t, end_time, jumps. constant_jumps,
216- maj,
217- save_positions, rng; spatial_system = spatial_system,
218- hopping_constants = hopping_constants, kwargs... )
219- constant_jump_callback = DiscreteCallback (disc_agg)
220- end
221-
222- if length (jumps. variable_jumps) > 0 && ! is_spatial (aggregator)
223- new_prob = extend_problem (prob, jumps; rng = rng)
224- variable_jump_callback = build_variable_callback (CallbackSet (), 0 ,
225- jumps. variable_jumps... ;
226- rng = rng)
227- cont_agg = jumps. variable_jumps
228- else
229- new_prob = prob
230- variable_jump_callback = CallbackSet ()
231- cont_agg = JumpSet (). variable_jumps
232- end
233231 end
234232
235233 jump_cbs = CallbackSet (constant_jump_callback, variable_jump_callback)
236-
237234 iip = isinplace_jump (prob, jumps. regular_jump)
238-
239235 solkwargs = make_kwarg (; callback)
236+
240237 JumpProblem{iip, typeof (new_prob), typeof (aggregator),
241238 typeof (jump_cbs), typeof (disc_agg),
242239 typeof (cont_agg),
@@ -248,7 +245,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
248245end
249246
250247function extend_problem (prob:: DiffEqBase.AbstractDiscreteProblem , jumps; rng = DEFAULT_RNG)
251- error (" VariableRateJumps require a continuous problem, like an ODE/SDE/DDE/DAE problem." )
248+ error (" General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs ." )
252249end
253250
254251function extend_problem (prob:: DiffEqBase.AbstractODEProblem , jumps; rng = DEFAULT_RNG)
@@ -257,19 +254,19 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL
257254 jump_f = let _f = _f
258255 function jump_f (du:: ExtendedJumpArray , u:: ExtendedJumpArray , p, t)
259256 _f (du. u, u. u, p, t)
260- update_jumps! (du, u, p, t, length (u. u), jumps. variable_jumps . .. )
257+ update_jumps! (du, u, p, t, length (u. u), jumps... )
261258 end
262259 end
263260 ttype = eltype (prob. tspan)
264261 u0 = ExtendedJumpArray (prob. u0,
265- [- randexp (rng, ttype) for i in 1 : length (jumps. variable_jumps )])
262+ [- randexp (rng, ttype) for i in 1 : length (jumps)])
266263 remake (prob, f = ODEFunction {true} (jump_f), u0 = u0)
267264end
268265
269266function extend_problem (prob:: DiffEqBase.AbstractSDEProblem , jumps; rng = DEFAULT_RNG)
270267 function jump_f (du, u, p, t)
271268 prob. f (du. u, u. u, p, t)
272- update_jumps! (du, u, p, t, length (u. u), jumps. variable_jumps . .. )
269+ update_jumps! (du, u, p, t, length (u. u), jumps... )
273270 end
274271
275272 if prob. noise_rate_prototype === nothing
@@ -284,30 +281,30 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL
284281
285282 ttype = eltype (prob. tspan)
286283 u0 = ExtendedJumpArray (prob. u0,
287- [- randexp (rng, ttype) for i in 1 : length (jumps. variable_jumps )])
284+ [- randexp (rng, ttype) for i in 1 : length (jumps)])
288285 remake (prob, f = SDEFunction {true} (jump_f, jump_g), g = jump_g, u0 = u0)
289286end
290287
291288function extend_problem (prob:: DiffEqBase.AbstractDDEProblem , jumps; rng = DEFAULT_RNG)
292289 jump_f = function (du, u, h, p, t)
293290 prob. f (du. u, u. u, h, p, t)
294- update_jumps! (du, u, p, t, length (u. u), jumps. variable_jumps . .. )
291+ update_jumps! (du, u, p, t, length (u. u), jumps... )
295292 end
296293 ttype = eltype (prob. tspan)
297294 u0 = ExtendedJumpArray (prob. u0,
298- [- randexp (rng, ttype) for i in 1 : length (jumps. variable_jumps )])
295+ [- randexp (rng, ttype) for i in 1 : length (jumps)])
299296 remake (prob, f = DDEFunction {true} (jump_f), u0 = u0)
300297end
301298
302299# Not sure if the DAE one is correct: Should be a residual of sorts
303300function extend_problem (prob:: DiffEqBase.AbstractDAEProblem , jumps; rng = DEFAULT_RNG)
304301 jump_f = function (out, du, u, p, t)
305302 prob. f (out. u, du. u, u. u, t)
306- update_jumps! (du, u, t, length (u. u), jumps. variable_jumps . .. )
303+ update_jumps! (du, u, t, length (u. u), jumps... )
307304 end
308305 ttype = eltype (prob. tspan)
309306 u0 = ExtendedJumpArray (prob. u0,
310- [- randexp (rng, ttype) for i in 1 : length (jumps. variable_jumps )])
307+ [- randexp (rng, ttype) for i in 1 : length (jumps)])
311308 remake (prob, f = DAEFunction {true} (jump_f), u0 = u0)
312309end
313310
0 commit comments