@@ -2,7 +2,9 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
22 nlsys, outer_tmp, inner_tmp = inner_nlsystem (sys, mm)
33 state = ProblemState (; u = u0, p)
44 op = Dict ()
5- op[ODE_GAMMA] = one (eltype (u0))
5+ op[ODE_GAMMA[1 ]] = one (eltype (u0))
6+ op[ODE_GAMMA[2 ]] = one (eltype (u0))
7+ op[ODE_GAMMA[3 ]] = one (eltype (u0))
68 op[ODE_C] = zero (eltype (u0))
79 op[outer_tmp] = zeros (eltype (u0), size (outer_tmp))
810 op[inner_tmp] = zeros (eltype (u0), size (inner_tmp))
@@ -11,15 +13,17 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
1113 op[v] = getsym (sys, v)(state)
1214 end
1315 nlprob = NonlinearProblem (nlsys, op; build_initializeprob = false )
16+
17+ subsetidxs = [findfirst (isequal (y),unknowns (sys)) for y in unknowns (nlsys)]
1418 set_gamma_c = setsym (nlsys, (ODE_GAMMA... , ODE_C))
1519 set_outer_tmp = setsym (nlsys, outer_tmp)
1620 set_inner_tmp = setsym (nlsys, inner_tmp)
1721 nlprobmap = getsym (nlsys, unknowns (sys))
1822
19- return SciMLBase. ODENLStepData (nlprob, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
23+ return SciMLBase. ODENLStepData (nlprob, subsetidxs, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
2024end
2125
22- const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ
26+ const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ
2327const ODE_C = only (@parameters cₘₜₖ)
2428
2529function get_outer_tmp (n:: Int )
@@ -38,19 +42,19 @@ function inner_nlsystem(sys::System, mm)
3842 @assert length (eqs) == N
3943 @assert mm == I || size (mm) == (N, N)
4044 rhss = [eq. rhs for eq in eqs]
41- gamma1, gamma2 = ODE_GAMMA
45+ gamma1, gamma2, gamma3 = ODE_GAMMA
4246 c = ODE_C
4347 outer_tmp = get_outer_tmp (N)
4448 inner_tmp = get_inner_tmp (N)
4549
4650 subrules = Dict ([v => gamma2* v + inner_tmp[i] for (i, v) in enumerate (dvs)])
4751 subrules[t] = t + c
4852 new_rhss = map (Base. Fix2 (fast_substitute, subrules), rhss)
49- new_rhss = mm * dvs - gamma1 .* new_rhss .+ collect (outer_tmp)
53+ new_rhss = collect (outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs
5054 new_eqs = [0 ~ rhs for rhs in new_rhss]
5155
5256 new_dvs = unknowns (sys)
53- new_ps = [parameters (sys); [gamma1, gamma2, c, inner_tmp, outer_tmp]]
57+ new_ps = [parameters (sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]]
5458 nlsys = mtkcompile (System (new_eqs, new_dvs, new_ps; name = :nlsys ); split = is_split (sys))
5559 return nlsys, outer_tmp, inner_tmp
5660end
0 commit comments