Skip to content

Commit 39a022c

Browse files
author
xiaoming
committed
add test 1.8 and minor fix
1 parent c391f6f commit 39a022c

File tree

2 files changed

+110
-103
lines changed

2 files changed

+110
-103
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
version:
2121
- '1.6'
2222
- '1.7'
23+
- '1.8'
2324
os:
2425
- ubuntu-latest
2526
arch:

src/delayproblem.jl

Lines changed: 109 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -71,44 +71,45 @@ delaysets = DelayJumpSet(delay_trigger,delay_complete,delay_interrupt)
7171
7272
"""
7373
mutable struct DelayJumpSet{T1<:Union{Function,Vector},T2<:Union{Function,Vector},T3<:Function}
74-
"""reactions in the Markovian part that trigger the change of the state of the delay channels or/and the state of the reactants upon initiation."""
75-
delay_trigger::Dict{Int,T1}
76-
"""reactions in the Markovian part that change the state of the delay channels or/and the state of the reactants in the middle of on-going delay reactions."""
77-
delay_complete::Dict{Int,T2}
78-
"""reactions that are initiated by delay trigger reactions and change the state of the delay channels or/and the state of the reactants upon completion."""
79-
delay_interrupt::Dict{Int,T3}
80-
"""collection of indices of reactions that can interrupt the delay reactions. of `delay_trigger`."""
81-
delay_trigger_set::Vector{Int}
82-
"""collection of indices of `delay_interrupt`."""
83-
delay_interrupt_set::Vector{Int}
74+
"""reactions in the Markovian part that trigger the change of the state of the delay channels or/and the state of the reactants upon initiation."""
75+
delay_trigger::Dict{Int,T1}
76+
"""reactions in the Markovian part that change the state of the delay channels or/and the state of the reactants in the middle of on-going delay reactions."""
77+
delay_complete::Dict{Int,T2}
78+
"""reactions that are initiated by delay trigger reactions and change the state of the delay channels or/and the state of the reactants upon completion."""
79+
delay_interrupt::Dict{Int,T3}
80+
"""collection of indices of reactions that can interrupt the delay reactions. of `delay_trigger`."""
81+
delay_trigger_set::Vector{Int}
82+
"""collection of indices of `delay_interrupt`."""
83+
delay_interrupt_set::Vector{Int}
8484
end
85-
function DelayJumpSet(delay_trigger::Dict,delay_complete::Dict,delay_interrupt::Dict)
86-
delay_trigger_, delay_complete_, delay_interrupt_ = convert_delayset.([delay_trigger, delay_complete, delay_interrupt])
87-
DelayJumpSet(delay_trigger_, delay_complete_, delay_interrupt_, collect(keys(delay_trigger_)), collect(keys(delay_interrupt_)))
85+
function DelayJumpSet(delay_trigger::Dict, delay_complete::Dict, delay_interrupt::Dict)
86+
delay_trigger_, delay_complete_, delay_interrupt_ = convert_delayset.([delay_trigger, delay_complete, delay_interrupt])
87+
DelayJumpSet(delay_trigger_, delay_complete_, delay_interrupt_, collect(keys(delay_trigger_)), collect(keys(delay_interrupt_)))
8888
end
8989

90-
convert_delayset(delay_set::Dict) = isempty(delay_set) ? convert(Dict{Int,Function},delay_set) : delay_set
90+
convert_delayset(delay_set::Dict) = isempty(delay_set) ? convert(Dict{Int,Function}, delay_set) : delay_set
9191

9292
#BEGIN DelayJump
93-
mutable struct DelayJumpProblem{iip,P,A,C,J<:Union{Nothing,AbstractJumpAggregator},J2,J3,J4,J5,deType} <: DiffEqBase.AbstractJumpProblem{P,J}
94-
prob::P
95-
aggregator::A
96-
discrete_jump_aggregation::J
97-
jump_callback::C
98-
variable_jumps::J2
99-
regular_jump::J3
100-
massaction_jump::J4
101-
delayjumpsets::J5
102-
de_chan0::deType
103-
save_delay_channel::Bool
93+
mutable struct DelayJumpProblem{iip,P,A,C,J<:Union{Nothing,AbstractJumpAggregator},J2,J3,J4,J5,deType,K} <: DiffEqBase.AbstractJumpProblem{P,J}
94+
prob::P
95+
aggregator::A
96+
discrete_jump_aggregation::J
97+
jump_callback::C
98+
variable_jumps::J2
99+
regular_jump::J3
100+
massaction_jump::J4
101+
delayjumpsets::J5
102+
de_chan0::deType
103+
save_delay_channel::Bool
104+
kwargs::K
104105
end
105106

106-
function DelayJumpProblem(p::P,a::A,dj::J,jc::C,vj::J2,rj::J3,mj::J4,djs::J5,de_chan0::deType,save_delay_channel::Bool) where {P,A,J,C,J2,J3,J4,J5,deType}
107-
if !(typeof(a)<:AbstractDelayAggregatorAlgorithm)
108-
error("To solve DelayJumpProblem, one has to use one of the delay aggregators.")
109-
end
110-
iip = isinplace_jump(p,rj)
111-
DelayJumpProblem{iip,P,A,C,J,J2,J3,J4,J5,deType}(p,a,dj,jc,vj,rj,mj,djs,de_chan0,save_delay_channel)
107+
function DelayJumpProblem(p::P, a::A, dj::J, jc::C, vj::J2, rj::J3, mj::J4, djs::J5, de_chan0::deType, save_delay_channel::Bool, kwargs::K) where {P,A,J,C,J2,J3,J4,J5,deType,K}
108+
if !(typeof(a) <: AbstractDelayAggregatorAlgorithm)
109+
error("To solve DelayJumpProblem, one has to use one of the delay aggregators.")
110+
end
111+
iip = isinplace_jump(p, rj)
112+
DelayJumpProblem{iip,P,A,C,J,J2,J3,J4,J5,deType,K}(p, a, dj, jc, vj, rj, mj, djs, de_chan0, save_delay_channel, kwargs)
112113
end
113114

114115

@@ -136,27 +137,27 @@ end
136137
The initial condition of the delay channel.
137138
"""
138139
function DelayJumpProblem(prob, aggregator::AbstractDelayAggregatorAlgorithm, jumps::JumpSet, delayjumpsets::DelayJumpSet, de_chan0;
139-
save_positions = typeof(prob) <: DiffEqBase.AbstractDiscreteProblem ? (false,true) : (true,true),
140-
rng = Xorshifts.Xoroshiro128Star(rand(UInt64)), scale_rates = false, useiszero = true, save_delay_channel = false, kwargs...)
140+
save_positions=typeof(prob) <: DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true),
141+
rng=Xorshifts.Xoroshiro128Star(rand(UInt64)), scale_rates=false, useiszero=true, save_delay_channel=false, callback=nothing, kwargs...)
141142

142143
# initialize the MassActionJump rate constants with the user parameters
143-
if using_params(jumps.massaction_jump)
144+
if using_params(jumps.massaction_jump)
144145
rates = jumps.massaction_jump.param_mapper(prob.p)
145-
maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch, jumps.massaction_jump.net_stoch,
146-
jumps.massaction_jump.param_mapper; scale_rates=scale_rates, useiszero=useiszero,
147-
nocopy=true)
146+
maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch, jumps.massaction_jump.net_stoch,
147+
jumps.massaction_jump.param_mapper; scale_rates=scale_rates, useiszero=useiszero,
148+
nocopy=true)
148149
else
149150
maj = jumps.massaction_jump
150151
end
151152

152153

153154
## Constant Rate Handling
154-
t,end_time,u = prob.tspan[1],prob.tspan[2],prob.u0
155+
t, end_time, u = prob.tspan[1], prob.tspan[2], prob.u0
155156
if (typeof(jumps.constant_jumps) <: Tuple{}) && (maj === nothing) # check if there are no jumps
156157
disc = nothing
157158
constant_jump_callback = CallbackSet()
158159
else
159-
disc = aggregate(aggregator,u,prob.p,t,end_time,jumps.constant_jumps,maj,save_positions,rng; kwargs...)
160+
disc = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps, maj, save_positions, rng; kwargs...)
160161
constant_jump_callback = DiscreteCallback(disc)
161162
end
162163

@@ -167,21 +168,23 @@ function DelayJumpProblem(prob, aggregator::AbstractDelayAggregatorAlgorithm, ju
167168
new_prob = prob
168169
variable_jump_callback = CallbackSet()
169170
else
170-
new_prob = extend_problem(prob,jumps)
171-
variable_jump_callback = build_variable_callback(CallbackSet(),0,jumps.variable_jumps...)
171+
new_prob = extend_problem(prob, jumps)
172+
variable_jump_callback = build_variable_callback(CallbackSet(), 0, jumps.variable_jumps...)
172173
end
173-
callbacks = CallbackSet(constant_jump_callback,variable_jump_callback)
174+
callbacks = CallbackSet(constant_jump_callback, variable_jump_callback)
175+
176+
solkwargs = make_kwarg(; callback)
174177

175178
DelayJumpProblem{iip,typeof(new_prob),typeof(aggregator),typeof(callbacks),
176-
typeof(disc),typeof(jumps.variable_jumps),
177-
typeof(jumps.regular_jump),typeof(maj),typeof(delayjumpsets),typeof(de_chan0)}(
178-
new_prob,aggregator,disc,
179-
callbacks,
180-
jumps.variable_jumps,
181-
jumps.regular_jump, maj, delayjumpsets, de_chan0, save_delay_channel)
179+
typeof(disc),typeof(jumps.variable_jumps),
180+
typeof(jumps.regular_jump),typeof(maj),typeof(delayjumpsets),typeof(de_chan0),typeof(solkwargs)}(
181+
new_prob, aggregator, disc,
182+
callbacks,
183+
jumps.variable_jumps,
184+
jumps.regular_jump, maj, delayjumpsets, de_chan0, save_delay_channel, solkwargs)
182185
end
183186

184-
187+
make_kwarg(; kwargs...) = kwargs
185188
"""
186189
function DelayJumpProblem(js::JumpSystem, prob, aggregator, delayjumpset, de_chan0; kwargs...)
187190
# Fields
@@ -203,34 +206,37 @@ end
203206
204207
The initial condition of the delay channel.
205208
"""
206-
function DelayJumpProblem(js::JumpSystem, prob, aggregator, delayjumpset, de_chan0; scale_rates = false, save_delay_channel = false, kwargs...)
207-
statetoid = Dict(value(state) => i for (i,state) in enumerate(states(js)))
208-
eqs = equations(js)
209-
invttype = prob.tspan[1] === nothing ? Float64 : typeof(1 / prob.tspan[2])
210-
211-
# handling parameter substition and empty param vecs
212-
p = (prob.p isa DiffEqBase.NullParameters || prob.p === nothing) ? Num[] : prob.p
213-
214-
majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs=eqs, rateconsttype=invttype)
215-
majs = isempty(eqs.x[1]) ? nothing : assemble_maj(eqs.x[1], statetoid, majpmapper)
216-
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
217-
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
218-
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
219-
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
220-
221-
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
222-
jdeps = asgraph(js)
223-
vdeps = variable_dependencies(js)
224-
vtoj = jdeps.badjlist
225-
jtov = vdeps.badjlist
226-
jtoj = needs_depgraph(aggregator) ? eqeq_dependencies(jdeps, vdeps).fadjlist : nothing
227-
dep_graph_delay = dep_gr_delay(delayjumpset, vtoj, length(eqs))
228-
else
229-
vtoj = nothing; jtov = nothing; jtoj = nothing; dep_graph_delay = nothing;
230-
end
209+
function DelayJumpProblem(js::JumpSystem, prob, aggregator, delayjumpset, de_chan0; scale_rates=false, save_delay_channel=false, kwargs...)
210+
statetoid = Dict(value(state) => i for (i, state) in enumerate(states(js)))
211+
eqs = equations(js)
212+
invttype = prob.tspan[1] === nothing ? Float64 : typeof(1 / prob.tspan[2])
213+
214+
# handling parameter substition and empty param vecs
215+
p = (prob.p isa DiffEqBase.NullParameters || prob.p === nothing) ? Num[] : prob.p
216+
217+
majpmapper = ModelingToolkit.JumpSysMajParamMapper(js, p; jseqs=eqs, rateconsttype=invttype)
218+
majs = isempty(eqs.x[1]) ? nothing : assemble_maj(eqs.x[1], statetoid, majpmapper)
219+
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
220+
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
221+
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
222+
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, majs)
223+
224+
if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
225+
jdeps = asgraph(js)
226+
vdeps = variable_dependencies(js)
227+
vtoj = jdeps.badjlist
228+
jtov = vdeps.badjlist
229+
jtoj = needs_depgraph(aggregator) ? eqeq_dependencies(jdeps, vdeps).fadjlist : nothing
230+
dep_graph_delay = dep_gr_delay(delayjumpset, vtoj, length(eqs))
231+
else
232+
vtoj = nothing
233+
jtov = nothing
234+
jtoj = nothing
235+
dep_graph_delay = nothing
236+
end
231237

232238

233-
DelayJumpProblem(prob, aggregator, jset, delayjumpset, de_chan0; save_delay_channel=save_delay_channel, dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov, scale_rates=scale_rates, dep_graph_delay = dep_graph_delay, nocopy=true, kwargs...)
239+
DelayJumpProblem(prob, aggregator, jset, delayjumpset, de_chan0; save_delay_channel=save_delay_channel, dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov, scale_rates=scale_rates, dep_graph_delay=dep_graph_delay, nocopy=true, kwargs...)
234240
end
235241

236242

@@ -242,47 +248,47 @@ function DiffEqBase.remake(thing::DelayJumpProblem; kwargs...)
242248
errmesg = """
243249
DelayJumpProblems can currently only be remade with new u0, de_chan0, p, tspan, delayjumpsets fields, prob fields.
244250
"""
245-
!issubset(keys(kwargs),((:u0,:de_chan0,:p,:tspan,:prob)...,propertynames(thing.delayjumpsets)...)) && error(errmesg)
251+
!issubset(keys(kwargs), ((:u0, :de_chan0, :p, :tspan, :prob)..., propertynames(thing.delayjumpsets)...)) && error(errmesg)
246252

247253
if :prob keys(kwargs)
248254
dprob = DiffEqBase.remake(thing.prob; kwargs...)
249255
# if the parameters were changed we must remake the MassActionJump too
250256
if (:p keys(kwargs)) && JumpProcesses.using_params(thing.massaction_jump)
251-
JumpProcesses.update_parameters!(thing.massaction_jump, dprob.p; kwargs...)
252-
end
257+
JumpProcesses.update_parameters!(thing.massaction_jump, dprob.p; kwargs...)
258+
end
253259
else
254-
any(k -> k in keys(kwargs), (:u0,:p,:tspan)) && error("If remaking a DelayJumpProblem you can not pass both prob and any of u0, p, or tspan.")
260+
any(k -> k in keys(kwargs), (:u0, :p, :tspan)) && error("If remaking a DelayJumpProblem you can not pass both prob and any of u0, p, or tspan.")
255261
dprob = kwargs[:prob]
256262

257263
# we can't know if p was changed, so we must remake the MassActionJump
258264
if JumpProcesses.using_params(thing.massaction_jump)
259265
JumpProcesses.update_parameters!(thing.massaction_jump, dprob.p; kwargs...)
260-
end
266+
end
261267
end
262-
if any(k -> k in keys(kwargs), propertynames(thing.delayjumpsets))
263-
delayjumpsets = update_delayjumpsets(thing.delayjumpsets; kwargs...)
268+
if any(k -> k in keys(kwargs), propertynames(thing.delayjumpsets))
269+
delayjumpsets = update_delayjumpsets(thing.delayjumpsets; kwargs...)
264270
else
265-
delayjumpsets = thing.delayjumpsets
271+
delayjumpsets = thing.delayjumpsets
266272
end
267273
de_chan0 = :de_chan0 keys(kwargs) ? kwargs[:de_chan0] : thing.de_chan0
268274

269275
DelayJumpProblem(dprob, thing.aggregator, thing.discrete_jump_aggregation, thing.jump_callback,
270-
thing.variable_jumps, thing.regular_jump, thing.massaction_jump, delayjumpsets, de_chan0, thing.save_delay_channel)
276+
thing.variable_jumps, thing.regular_jump, thing.massaction_jump, delayjumpsets, de_chan0, thing.save_delay_channel, thing.kwargs)
271277
end
272278

273279
function update_delayjumpsets(delayjumpsets::DelayJumpSet; kwargs...)
274-
@unpack delay_trigger, delay_complete, delay_interrupt = delayjumpsets
275-
for (key, value) in kwargs
276-
exp_key = toexpr(key)
277-
if exp_key == :delay_trigger
278-
delay_trigger = value
279-
elseif exp_key == :delay_complete
280-
delay_complete = value
281-
elseif exp_key == :delay_interrupt
282-
delay_interrupt = value
283-
end
280+
@unpack delay_trigger, delay_complete, delay_interrupt = delayjumpsets
281+
for (key, value) in kwargs
282+
exp_key = toexpr(key)
283+
if exp_key == :delay_trigger
284+
delay_trigger = value
285+
elseif exp_key == :delay_complete
286+
delay_complete = value
287+
elseif exp_key == :delay_interrupt
288+
delay_interrupt = value
284289
end
285-
DelayJumpSet(delay_trigger, delay_complete, delay_interrupt)
290+
end
291+
DelayJumpSet(delay_trigger, delay_complete, delay_interrupt)
286292
end
287293

288294
# function update_delayjumpsets(delayjumpsets::DelayJumpSet; kwargs...)
@@ -299,20 +305,20 @@ end
299305
# return delayjumpsets_
300306
# end
301307

302-
Base.summary(io::IO, prob::DelayJumpProblem) = string(DiffEqBase.parameterless_type(prob)," with problem ",DiffEqBase.parameterless_type(prob.prob)," and aggregator ",typeof(prob.aggregator))
308+
Base.summary(io::IO, prob::DelayJumpProblem) = string(DiffEqBase.parameterless_type(prob), " with problem ", DiffEqBase.parameterless_type(prob.prob), " and aggregator ", typeof(prob.aggregator))
303309
function Base.show(io::IO, mime::MIME"text/plain", A::DelayJumpProblem)
304-
println(io,summary(A))
305-
println(io,"Number of constant rate jumps: ",A.discrete_jump_aggregation === nothing ? 0 : num_constant_rate_jumps(A.discrete_jump_aggregation))
306-
println(io,"Number of variable rate jumps: ",length(A.variable_jumps))
310+
println(io, summary(A))
311+
println(io, "Number of constant rate jumps: ", A.discrete_jump_aggregation === nothing ? 0 : num_constant_rate_jumps(A.discrete_jump_aggregation))
312+
println(io, "Number of variable rate jumps: ", length(A.variable_jumps))
307313
if A.regular_jump !== nothing
308-
println(io,"Have a regular jump")
314+
println(io, "Have a regular jump")
309315
end
310316
if (A.massaction_jump !== nothing) && (get_num_majumps(A.massaction_jump) > 0)
311-
println(io,"Have a mass action jump")
317+
println(io, "Have a mass action jump")
312318
end
313319
if A.delayjumpsets !== nothing
314-
println(io,"Number of delay trigger reactions: ",length(A.delayjumpsets.delay_trigger))
315-
println(io,"Number of delay interrupt reactions: ",length(A.delayjumpsets.delay_interrupt))
320+
println(io, "Number of delay trigger reactions: ", length(A.delayjumpsets.delay_trigger))
321+
println(io, "Number of delay interrupt reactions: ", length(A.delayjumpsets.delay_interrupt))
316322
end
317323
end
318324
#END DelayJump

0 commit comments

Comments
 (0)