Skip to content

Commit 5bcd2e1

Browse files
Remove custom PresetTimeCallback
1 parent 024b4fd commit 5bcd2e1

File tree

1 file changed

+6
-54
lines changed

1 file changed

+6
-54
lines changed

src/time_evolution/callback_helpers.jl

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function _generate_sesolve_callback(e_ops, tlist, progress_bar)
4747
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
4848

4949
_save_affect! = SaveFuncSESolve(e_ops_data, progr, Ref(1), expvals)
50-
return _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
50+
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
5151
end
5252

5353
function _sesolve_get_expvals(sol::AbstractODESolution)
@@ -62,7 +62,8 @@ function _sesolve_get_expvals(cb::CallbackSet)
6262
_cb = cb.discrete_callbacks[1]
6363
return _sesolve_get_expvals(_cb)
6464
end
65-
_sesolve_get_expvals(cb::DiscreteCallback) = if cb.affect! isa SaveFuncSESolve
65+
_sesolve_get_expvals(cb::DiscreteCallback) =
66+
if cb.affect! isa SaveFuncSESolve
6667
return cb.affect!.expvals
6768
else
6869
return nothing
@@ -130,7 +131,7 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
130131
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
131132

132133
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
133-
cb2 = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
134+
cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
134135
kwargs2 =
135136
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
136137
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
@@ -241,14 +242,14 @@ function _mcsolve_callbacks_new_iter_expvals(cb::CallbackSet, tlist)
241242
e_ops = cb_discrete[idx].affect!.e_ops
242243
expvals = similar(cb_discrete[idx].affect!.expvals)
243244
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
244-
cb_save = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
245+
cb_save = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
245246
return CallbackSet(cb_continuous..., cb_save, cb_discrete[2:end]...)
246247
else
247248
idx = 2
248249
e_ops = cb_discrete[idx].affect!.e_ops
249250
expvals = similar(cb_discrete[idx].affect!.expvals)
250251
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
251-
cb_save = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
252+
cb_save = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
252253
return CallbackSet(cb_continuous..., cb_discrete[1], cb_save, cb_discrete[3:end]...)
253254
end
254255
end
@@ -263,52 +264,3 @@ _mcsolve_has_continuous_jump(cb::CallbackSet) =
263264
(length(cb.continuous_callbacks) > 0) && (cb.continuous_callbacks[1].affect! isa LindbladJump)
264265
_mcsolve_has_continuous_jump(cb::ContinuousCallback) = true
265266
_mcsolve_has_continuous_jump(cb::DiscreteCallback) = false
266-
267-
## Temporary function to avoid errors. Waiting for the PR In DiffEqCallbacks.jl to be merged.
268-
269-
import SciMLBase: INITIALIZE_DEFAULT, add_tstop!
270-
271-
function _PresetTimeCallback(
272-
tstops,
273-
user_affect!;
274-
initialize = INITIALIZE_DEFAULT,
275-
filter_tstops = true,
276-
sort_inplace = false,
277-
kwargs...,
278-
)
279-
if !(tstops isa AbstractVector) && !(tstops isa Number)
280-
throw(ArgumentError("tstops must either be a number or a Vector. Was $tstops"))
281-
end
282-
283-
tstops = tstops isa Number ? [tstops] : (sort_inplace ? sort!(tstops) : sort(tstops))
284-
285-
condition = let
286-
function (u, t, integrator)
287-
if hasproperty(integrator, :dt)
288-
insorted(t, tstops) && (integrator.t - integrator.dt) != integrator.t
289-
else
290-
insorted(t, tstops)
291-
end
292-
end
293-
end
294-
295-
# Initialization: first call to `f` should be *before* any time steps have been taken:
296-
initialize_preset = function (c, u, t, integrator)
297-
initialize(c, u, t, integrator)
298-
299-
if filter_tstops
300-
tdir = integrator.tdir
301-
tspan = integrator.sol.prob.tspan
302-
_tstops = tstops[@. tdir * tspan[1] < tdir * tstops < tdir * tspan[2]]
303-
else
304-
_tstops = tstops
305-
end
306-
for tstop in _tstops
307-
add_tstop!(integrator, tstop)
308-
end
309-
if t tstops
310-
user_affect!(integrator)
311-
end
312-
end
313-
return DiscreteCallback(condition, user_affect!; initialize = initialize_preset, kwargs...)
314-
end

0 commit comments

Comments
 (0)