Skip to content

Commit 16f3836

Browse files
committed
setup tstops for ODE solvers
1 parent 0846c9e commit 16f3836

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@ This file contains helper functions for callbacks. The affect! function are defi
66

77
abstract type AbstractSaveFunc end
88

9+
function _kwargs_set_tstops(kwargs, tlist)
10+
tstops = haskey(kwargs, :tstops) ? unique!(sort!(vcat(tlist, kwargs.tstops))) : tlist
11+
return merge(kwargs, (tstops = tstops,))
12+
end
13+
914
# Multiple dispatch depending on the progress_bar and e_ops types
1015
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
1116
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
12-
return _merge_kwargs_with_callback(kwargs, cb)
17+
18+
kwargs2 = _kwargs_set_tstops(kwargs, tlist)
19+
return _merge_kwargs_with_callback(kwargs2, cb)
1320
end
14-
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs
21+
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) =
22+
_kwargs_set_tstops(kwargs, tlist)
1523

1624
function _generate_stochastic_kwargs(
1725
e_ops,
@@ -26,8 +34,7 @@ function _generate_stochastic_kwargs(
2634

2735
# Ensure that the noise is stored in tlist. # TODO: Fix this directly in DiffEqNoiseProcess.jl
2836
# See https://github.com/SciML/DiffEqNoiseProcess.jl/issues/214 for example
29-
tstops = haskey(kwargs, :tstops) ? unique!(sort!(vcat(tlist, kwargs.tstops))) : tlist
30-
kwargs2 = merge(kwargs, (tstops = tstops,))
37+
kwargs2 = _kwargs_set_tstops(kwargs, tlist)
3138

3239
if SF === SaveFuncSSESolve
3340
cb_normalize = _ssesolve_generate_normalize_cb()
@@ -44,7 +51,7 @@ _generate_stochastic_kwargs(
4451
store_measurement::Val{false},
4552
kwargs,
4653
method::Type{SF},
47-
) where {SF<:AbstractSaveFunc} = kwargs
54+
) where {SF<:AbstractSaveFunc} = _kwargs_set_tstops(kwargs, tlist)
4855

4956
function _merge_kwargs_with_callback(kwargs, cb)
5057
kwargs2 =

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,22 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
105105
)
106106
end
107107

108+
kwargs2 = _kwargs_set_tstops(kwargs, tlist)
108109
if e_ops isa Nothing
109110
# We are implicitly saying that we don't have a `Progress`
110-
kwargs2 =
111-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, kwargs.callback),)) :
112-
merge(kwargs, (callback = cb1,))
113-
return kwargs2
111+
kwargs3 =
112+
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, kwargs2.callback),)) :
113+
merge(kwargs2, (callback = cb1,))
114+
return kwargs3
114115
else
115116
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
116117

117118
_save_func = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
118119
cb2 = FunctionCallingCallback(_save_func, funcat = tlist)
119-
kwargs2 =
120-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
121-
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
122-
return kwargs2
120+
kwargs3 =
121+
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
122+
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
123+
return kwargs3
123124
end
124125
end
125126

0 commit comments

Comments
 (0)