Skip to content

Commit 300fd5d

Browse files
Change save callbacks from PresetTimeCallback to FunctionCallingCallback (#410)
Co-authored-by: Yi-Te Huang <[email protected]>
1 parent d48371c commit 300fd5d

File tree

9 files changed

+113
-64
lines changed

9 files changed

+113
-64
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)
99

1010
- Support for single `AbstractQuantumObject` in `sc_ops` for faster specific method in `ssesolve` and `smesolve`. ([#408])
11+
- Change save callbacks from `PresetTimeCallback` to `FunctionCallingCallback`. ([#410])
1112
- Align `eigenstates` and `eigenenergies` to QuTiP. ([#411])
1213

1314
## [v0.27.0]
@@ -144,4 +145,5 @@ Release date: 2024-11-13
144145
[#404]: https://github.com/qutip/QuantumToolbox.jl/issues/404
145146
[#405]: https://github.com/qutip/QuantumToolbox.jl/issues/405
146147
[#408]: https://github.com/qutip/QuantumToolbox.jl/issues/408
148+
[#410]: https://github.com/qutip/QuantumToolbox.jl/issues/410
147149
[#411]: https://github.com/qutip/QuantumToolbox.jl/issues/411

src/QuantumToolbox.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ import SciMLOperators:
5353
concretize
5454
import LinearSolve: LinearProblem, SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES
5555
import DiffEqBase: get_tstops
56-
import DiffEqCallbacks: PeriodicCallback, PresetTimeCallback, TerminateSteadyState
56+
import DiffEqCallbacks: PeriodicCallback, FunctionCallingCallback, FunctionCallingAffect, TerminateSteadyState
5757
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm
5858
import OrdinaryDiffEqTsit5: Tsit5
5959
import DiffEqNoiseProcess: RealWienerProcess!, RealWienerProcess

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)
5656

5757
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
5858

59-
_save_affect! = method(e_ops_data, progr, Ref(1), expvals)
60-
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
59+
_save_func = method(e_ops_data, progr, Ref(1), expvals)
60+
return FunctionCallingCallback(_save_func, funcat = tlist)
6161
end
6262

6363
function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
@@ -69,8 +69,8 @@ function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurem
6969
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
7070
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing
7171

72-
_save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
73-
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
72+
_save_func = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
73+
return FunctionCallingCallback(_save_func, funcat = tlist)
7474
end
7575

7676
##
@@ -98,20 +98,20 @@ function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where
9898
if cb isa Nothing
9999
return nothing
100100
else
101-
return cb.affect!.m_expvals
101+
return cb.affect!.func.m_expvals
102102
end
103103
end
104104

105105
#=
106106
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
107-
This callback can only be a PresetTimeCallback (DiscreteCallback).
107+
This callback can only be a FunctionCallingCallback (DiscreteCallback).
108108
=#
109109
function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc}
110110
cb = _get_save_callback(integrator, method)
111111
if cb isa Nothing
112112
return nothing
113113
else
114-
return cb.affect!.e_ops
114+
return cb.affect!.func.e_ops
115115
end
116116
end
117117

@@ -121,7 +121,7 @@ function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:Abs
121121
if cb isa Nothing
122122
return nothing
123123
else
124-
return cb.affect!.expvals
124+
return cb.affect!.func.expvals
125125
end
126126
end
127127

@@ -151,7 +151,7 @@ function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:Abstra
151151
end
152152
end
153153
function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc}
154-
if typeof(cb.affect!) <: AbstractSaveFunc
154+
if typeof(cb.affect!) <: FunctionCallingAffect && typeof(cb.affect!.func) <: AbstractSaveFunc
155155
return cb
156156
end
157157
return nothing

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct SaveFuncMCSolve{TE,IT,TEXPV} <: AbstractSaveFunc
88
expvals::TEXPV
99
end
1010

11-
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
11+
(f::SaveFuncMCSolve)(u, t, integrator) = _save_func_mcsolve(u, integrator, f.e_ops, f.iter, f.expvals)
1212

1313
_get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2
1414

@@ -52,10 +52,10 @@ end
5252

5353
##
5454

55-
function _save_func_mcsolve(integrator, e_ops, iter, expvals)
55+
function _save_func_mcsolve(u, integrator, e_ops, iter, expvals)
5656
cache_mc = _mc_get_jump_callback(integrator).affect!.cache_mc
5757

58-
copyto!(cache_mc, integrator.u)
58+
copyto!(cache_mc, u)
5959
normalize!(cache_mc)
6060
ψ = cache_mc
6161
_expect = op -> dot(ψ, op, ψ)
@@ -114,8 +114,8 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
114114
else
115115
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
116116

117-
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
118-
cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
117+
_save_func = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
118+
cb2 = FunctionCallingCallback(_save_func, funcat = tlist)
119119
kwargs2 =
120120
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
121121
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
@@ -214,11 +214,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)
214214

215215
if _mcsolve_has_continuous_jump(cb)
216216
idx = 1
217-
if cb_discrete[idx].affect! isa SaveFuncMCSolve
218-
e_ops = cb_discrete[idx].affect!.e_ops
219-
expvals = similar(cb_discrete[idx].affect!.expvals)
220-
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
221-
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
217+
if cb_discrete[idx].affect!.func isa SaveFuncMCSolve
218+
e_ops = cb_discrete[idx].affect!.func.e_ops
219+
expvals = similar(cb_discrete[idx].affect!.func.expvals)
220+
_save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals)
221+
cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),)
222222
else
223223
cb_save = ()
224224
end
@@ -229,11 +229,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)
229229
return CallbackSet((cb_jump, cb_continuous[2:end]...), (cb_save..., cb_discrete[2:end]...))
230230
else
231231
idx = 2
232-
if cb_discrete[idx].affect! isa SaveFuncMCSolve
233-
e_ops = cb_discrete[idx].affect!.e_ops
234-
expvals = similar(cb_discrete[idx].affect!.expvals)
235-
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
236-
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
232+
if cb_discrete[idx].affect!.func isa SaveFuncMCSolve
233+
e_ops = cb_discrete[idx].affect!.func.e_ops
234+
expvals = similar(cb_discrete[idx].affect!.func.expvals)
235+
_save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals)
236+
cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),)
237237
else
238238
cb_save = ()
239239
end

src/time_evolution/callback_helpers/mesolve_callback_helpers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@ struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing
99
expvals::TEXPV
1010
end
1111

12-
(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
13-
(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr)
12+
(f::SaveFuncMESolve)(u, t, integrator) = _save_func_mesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals)
13+
(f::SaveFuncMESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr)
1414

1515
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
1616

1717
##
1818

1919
# When e_ops is a list of operators
20-
function _save_func_mesolve(integrator, e_ops, progr, iter, expvals)
20+
function _save_func_mesolve(u, integrator, e_ops, progr, iter, expvals)
2121
# This is equivalent to tr(op * ρ), when both are matrices.
2222
# The advantage of using this convention is that We don't need
2323
# to reshape u to make it a matrix, but we reshape the e_ops once.
2424

25-
ρ = integrator.u
25+
ρ = u
2626
_expect = op -> dot(op, ρ)
2727
@. expvals[:, iter[]] = _expect(e_ops)
2828
iter[] += 1
@@ -35,7 +35,7 @@ function _mesolve_callbacks_new_e_ops!(integrator::AbstractODEIntegrator, e_ops)
3535
if cb isa Nothing
3636
return nothing
3737
else
38-
cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators
38+
cb.affect!.func.e_ops .= e_ops # Only works if e_ops is a Vector of operators
3939
return nothing
4040
end
4141
end

src/time_evolution/callback_helpers/sesolve_callback_helpers.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing
99
expvals::TEXPV
1010
end
1111

12-
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
13-
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
12+
(f::SaveFuncSESolve)(u, t, integrator) = _save_func_sesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals)
13+
(f::SaveFuncSESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
1414

1515
_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
1616

1717
##
1818

1919
# When e_ops is a list of operators
20-
function _save_func_sesolve(integrator, e_ops, progr, iter, expvals)
21-
ψ = integrator.u
20+
function _save_func_sesolve(u, integrator, e_ops, progr, iter, expvals)
21+
ψ = u
2222
_expect = op -> dot(ψ, op, ψ)
2323
@. expvals[:, iter[]] = _expect(e_ops)
2424
iter[] += 1

src/time_evolution/callback_helpers/smesolve_callback_helpers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ struct SaveFuncSMESolve{
2020
m_expvals::TMEXPV
2121
end
2222

23-
(f::SaveFuncSMESolve)(integrator) =
24-
_save_func_smesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
25-
(f::SaveFuncSMESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
23+
(f::SaveFuncSMESolve)(u, t, integrator) =
24+
_save_func_smesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
25+
(f::SaveFuncSMESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers
2626

2727
_get_e_ops_data(e_ops, ::Type{SaveFuncSMESolve}) = _get_e_ops_data(e_ops, SaveFuncMESolve)
2828
_get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) =
@@ -31,12 +31,12 @@ _get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) =
3131
##
3232

3333
# When e_ops is a list of operators
34-
function _save_func_smesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
34+
function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
3535
# This is equivalent to tr(op * ρ), when both are matrices.
3636
# The advantage of using this convention is that We don't need
3737
# to reshape u to make it a matrix, but we reshape the e_ops once.
3838

39-
ρ = integrator.u
39+
ρ = u
4040

4141
_expect = op -> dot(op, ρ)
4242

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ struct SaveFuncSSESolve{
2020
m_expvals::TMEXPV
2121
end
2222

23-
(f::SaveFuncSSESolve)(integrator) =
24-
_save_func_ssesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
25-
(f::SaveFuncSSESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
23+
(f::SaveFuncSSESolve)(u, t, integrator) =
24+
_save_func_ssesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
25+
(f::SaveFuncSSESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers
2626

2727
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
2828
_get_m_ops_data(sc_ops, ::Type{SaveFuncSSESolve}) = map(op -> Hermitian(get_data(op) + get_data(op)'), sc_ops)
@@ -32,8 +32,8 @@ _get_save_callback_idx(cb, ::Type{SaveFuncSSESolve}) = 2 # The first one is the
3232
##
3333

3434
# When e_ops is a list of operators
35-
function _save_func_ssesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
36-
ψ = integrator.u
35+
function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
36+
ψ = u
3737

3838
_expect = op -> dot(ψ, op, ψ)
3939

0 commit comments

Comments
 (0)