Skip to content

Commit aaaf0d3

Browse files
authored
Fix time evolution output when using saveat (#398)
2 parents 0c0adcf + 72e7b79 commit aaaf0d3

File tree

8 files changed

+40
-34
lines changed

8 files changed

+40
-34
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
- Fix erroneous definition of the stochastic term in `smesolve`. ([#393])
1212
- Change name of `MultiSiteOperator` to `multisite_operator`. ([#394])
1313
- Fix `smesolve` for specifying initial state as density matrix. ([#395])
14+
- Fix time evolution output when using `saveat` keyword argument. ([#398])
1415

1516
## [v0.26.0]
1617
Release date: 2025-02-09
@@ -124,3 +125,4 @@ Release date: 2024-11-13
124125
[#393]: https://github.com/qutip/QuantumToolbox.jl/issues/393
125126
[#394]: https://github.com/qutip/QuantumToolbox.jl/issues/394
126127
[#395]: https://github.com/qutip/QuantumToolbox.jl/issues/395
128+
[#398]: https://github.com/qutip/QuantumToolbox.jl/issues/398

src/time_evolution/mcsolve.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,9 @@ function mcsolveProblem(
131131

132132
T = Base.promote_eltype(H_eff_evo, ψ0)
133133

134-
is_empty_e_ops = e_ops isa Nothing ? true : isempty(e_ops)
135-
136-
saveat = is_empty_e_ops ? tlist : [tlist[end]]
137134
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
138-
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false))
139-
kwargs2 = merge(default_values, kwargs)
135+
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., progress_bar = Val(false))
136+
kwargs2 = _merge_saveat(tlist, e_ops, default_values; kwargs...)
140137
kwargs3 = _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rng, kwargs2)
141138

142139
return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs3...)

src/time_evolution/mesolve.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,7 @@ function mesolveProblem(
7979
ρ0 = to_dense(_CType(T), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type
8080
L = L_evo.data
8181

82-
is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops)
83-
84-
saveat = is_empty_e_ops ? tlist : [tlist[end]]
85-
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
86-
kwargs2 = merge(default_values, kwargs)
82+
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
8783
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve)
8884

8985
tspan = (tlist[1], tlist[end])

src/time_evolution/sesolve.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,7 @@ function sesolveProblem(
6969
ψ0 = to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type
7070
U = H_evo.data
7171

72-
is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops)
73-
74-
saveat = is_empty_e_ops ? tlist : [tlist[end]]
75-
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
76-
kwargs2 = merge(default_values, kwargs)
72+
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
7773
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve)
7874

7975
tspan = (tlist[1], tlist[end])

src/time_evolution/smesolve.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,7 @@ function smesolveProblem(
112112

113113
p = (progr = progr, times = tlist, Hdims = dims, n_sc_ops = length(sc_ops), params...)
114114

115-
is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops)
116-
117-
saveat = is_empty_e_ops ? tlist : [tlist[end]]
118-
default_values = (DEFAULT_SDE_SOLVER_OPTIONS..., saveat = saveat)
119-
kwargs2 = merge(default_values, kwargs)
115+
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_SDE_SOLVER_OPTIONS; kwargs...)
120116
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve)
121117

122118
tspan = (tlist[1], tlist[end])

src/time_evolution/ssesolve.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,7 @@ function ssesolveProblem(
186186

187187
p = (progr = progr, times = tlist, Hdims = dims, n_sc_ops = length(sc_ops), params...)
188188

189-
is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops)
190-
191-
saveat = is_empty_e_ops ? tlist : [tlist[end]]
192-
default_values = (DEFAULT_SDE_SOLVER_OPTIONS..., saveat = saveat)
193-
kwargs2 = merge(default_values, kwargs)
189+
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_SDE_SOLVER_OPTIONS; kwargs...)
194190
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSSESolve)
195191
kwargs4 = _ssesolve_add_normalize_cb(kwargs3)
196192

src/time_evolution/time_evolution.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,23 @@ function _check_tlist(tlist, T::Type)
230230
return tlist2
231231
end
232232

233+
#######################################
234+
235+
function _merge_saveat(tlist, e_ops, default_options; kwargs...)
236+
is_empty_e_ops = isnothing(e_ops) ? true : isempty(e_ops)
237+
saveat = is_empty_e_ops ? tlist : [tlist[end]]
238+
default_values = (default_options..., saveat = saveat)
239+
kwargs2 = merge(default_values, kwargs)
240+
241+
# DifferentialEquations.jl has this weird save_end setting
242+
# So we need to do this to make sure it's consistent
243+
haskey(kwargs, :save_end) && return kwargs2
244+
isempty(kwargs2.saveat) && return kwargs2
245+
246+
save_end = tlist[end] in kwargs2.saveat
247+
return merge(kwargs2, (save_end = save_end,))
248+
end
249+
233250
#######################################
234251
#=
235252
Helpers for handling output of ensemble problems.

test/core-test/time_evolution.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626

2727
@testset "sesolve" begin
2828
tlist = range(0, 20 * 2π / g, 1000)
29+
saveat_idxs = 500:900
30+
saveat = tlist[saveat_idxs]
2931

3032
prob = sesolveProblem(H, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false))
3133
sol = sesolve(prob)
3234
sol2 = sesolve(H, ψ0, tlist, progress_bar = Val(false))
33-
sol3 = sesolve(H, ψ0, tlist, e_ops = e_ops, saveat = tlist, progress_bar = Val(false))
35+
sol3 = sesolve(H, ψ0, tlist, e_ops = e_ops, saveat = saveat, progress_bar = Val(false))
3436
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
3537
sol_string2 = sprint((t, s) -> show(t, "text/plain", s), sol2)
3638

@@ -48,8 +50,9 @@
4850
@test length(sol2.states) == length(tlist)
4951
@test sol2.expect === nothing
5052
@test length(sol3.times) == length(tlist)
51-
@test length(sol3.states) == length(tlist)
53+
@test length(sol3.states) == length(saveat)
5254
@test size(sol3.expect) == (length(e_ops), length(tlist))
55+
@test sol.expect[1, saveat_idxs] expect(e_ops[1], sol3.states) atol = 1e-6
5356
@test sol_string ==
5457
"Solution of time evolution\n" *
5558
"(return code: $(sol.retcode))\n" *
@@ -92,18 +95,20 @@
9295
@inferred sesolveProblem(H, ψ0_int, tlist, progress_bar = Val(false))
9396
@inferred sesolve(H, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false))
9497
@inferred sesolve(H, ψ0, tlist, progress_bar = Val(false))
95-
@inferred sesolve(H, ψ0, tlist, e_ops = e_ops, saveat = tlist, progress_bar = Val(false))
98+
@inferred sesolve(H, ψ0, tlist, e_ops = e_ops, saveat = saveat, progress_bar = Val(false))
9699
@inferred sesolve(H, ψ0, tlist, e_ops = (a' * a, a'), progress_bar = Val(false)) # We test the type inference for Tuple of different types
97100
end
98101
end
99102

100103
@testset "mesolve, mcsolve, ssesolve and smesolve" begin
101104
tlist = range(0, 10 / γ, 100)
105+
saveat_idxs = 50:90
106+
saveat = tlist[saveat_idxs]
102107

103108
prob_me = mesolveProblem(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false))
104109
sol_me = mesolve(prob_me)
105110
sol_me2 = mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
106-
sol_me3 = mesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, saveat = tlist, progress_bar = Val(false))
111+
sol_me3 = mesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, saveat = saveat, progress_bar = Val(false))
107112
prob_mc = mcsolveProblem(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false))
108113
sol_mc = mcsolve(H, ψ0, tlist, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
109114
sol_mc2 = mcsolve(
@@ -116,14 +121,14 @@
116121
progress_bar = Val(false),
117122
jump_callback = DiscreteLindbladJumpCallback(),
118123
)
119-
sol_mc_states = mcsolve(H, ψ0, tlist, c_ops, ntraj = 500, saveat = tlist, progress_bar = Val(false))
124+
sol_mc_states = mcsolve(H, ψ0, tlist, c_ops, ntraj = 500, saveat = saveat, progress_bar = Val(false))
120125
sol_mc_states2 = mcsolve(
121126
H,
122127
ψ0,
123128
tlist,
124129
c_ops,
125130
ntraj = 500,
126-
saveat = tlist,
131+
saveat = saveat,
127132
progress_bar = Val(false),
128133
jump_callback = DiscreteLindbladJumpCallback(),
129134
)
@@ -147,8 +152,8 @@
147152
@test prob_mc.prob.f.f isa MatrixOperator
148153
@test sum(abs, sol_mc.expect .- sol_me.expect) / length(tlist) < 0.1
149154
@test sum(abs, sol_mc2.expect .- sol_me.expect) / length(tlist) < 0.1
150-
@test sum(abs, vec(expect_mc_states_mean) .- vec(sol_me.expect[1, :])) / length(tlist) < 0.1
151-
@test sum(abs, vec(expect_mc_states_mean2) .- vec(sol_me.expect[1, :])) / length(tlist) < 0.1
155+
@test sum(abs, vec(expect_mc_states_mean) .- vec(sol_me.expect[1, saveat_idxs])) / length(tlist) < 0.1
156+
@test sum(abs, vec(expect_mc_states_mean2) .- vec(sol_me.expect[1, saveat_idxs])) / length(tlist) < 0.1
152157
@test sum(abs, sol_sse.expect .- sol_me.expect) / length(tlist) < 0.1
153158
@test sum(abs, sol_sme.expect .- sol_me.expect) / length(tlist) < 0.1
154159
@test length(sol_me.times) == length(tlist)
@@ -158,8 +163,9 @@
158163
@test length(sol_me2.states) == length(tlist)
159164
@test sol_me2.expect === nothing
160165
@test length(sol_me3.times) == length(tlist)
161-
@test length(sol_me3.states) == length(tlist)
166+
@test length(sol_me3.states) == length(saveat)
162167
@test size(sol_me3.expect) == (length(e_ops), length(tlist))
168+
@test sol_me3.expect[1, saveat_idxs] expect(e_ops[1], sol_me3.states) atol = 1e-6
163169
@test length(sol_mc.times) == length(tlist)
164170
@test size(sol_mc.expect) == (length(e_ops), length(tlist))
165171
@test length(sol_mc_states.times) == length(tlist)

0 commit comments

Comments
 (0)