Skip to content

Commit ed5b0da

Browse files
Align some attributes of mcsolve, ssesolve and smesolve results with QuTiP (#402)
1 parent dff6aa1 commit ed5b0da

File tree

7 files changed

+69
-60
lines changed

7 files changed

+69
-60
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Release date: 2025-01-29
3434
- Fix Dynamical Fock Dimension states saving due to wrong saving of dimensions. ([#375])
3535
- Support a list of observables for `expect`. ([#374], [#376])
3636
- Add checks for `tlist` in time evolution solvers. The checks are to ensure that `tlist` is not empty, the elements are in increasing order, and the elements are unique. ([#378])
37+
- Change the definition of jump_times and jump_which into col_times and col_which, respectively. ([#402])
3738

3839
## [v0.25.0]
3940
Release date: 2025-01-20
@@ -129,4 +130,5 @@ Release date: 2024-11-13
129130
[#395]: https://github.com/qutip/QuantumToolbox.jl/issues/395
130131
[#396]: https://github.com/qutip/QuantumToolbox.jl/issues/396
131132
[#398]: https://github.com/qutip/QuantumToolbox.jl/issues/398
133+
[#402]: https://github.com/qutip/QuantumToolbox.jl/issues/402
132134
[#403]: https://github.com/qutip/QuantumToolbox.jl/issues/403

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ struct LindbladJump{
2828
cache_mc::CT
2929
weights_mc::WT
3030
cumsum_weights_mc::WT
31-
jump_times::JTT
32-
jump_which::JWT
33-
jump_times_which_idx::JTWIT
31+
col_times::JTT
32+
col_which::JWT
33+
col_times_which_idx::JTWIT
3434
end
3535

3636
(f::LindbladJump)(integrator) = _lindblad_jump_affect!(
@@ -42,9 +42,9 @@ end
4242
f.cache_mc,
4343
f.weights_mc,
4444
f.cumsum_weights_mc,
45-
f.jump_times,
46-
f.jump_which,
47-
f.jump_times_which_idx,
45+
f.col_times,
46+
f.col_which,
47+
f.col_times_which_idx,
4848
)
4949

5050
##
@@ -71,9 +71,9 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
7171
weights_mc = Vector{Float64}(undef, length(c_ops))
7272
cumsum_weights_mc = similar(weights_mc)
7373

74-
jump_times = Vector{Float64}(undef, JUMP_TIMES_WHICH_INIT_SIZE)
75-
jump_which = Vector{Int}(undef, JUMP_TIMES_WHICH_INIT_SIZE)
76-
jump_times_which_idx = Ref(1)
74+
col_times = Vector{Float64}(undef, COL_TIMES_WHICH_INIT_SIZE)
75+
col_which = Vector{Int}(undef, COL_TIMES_WHICH_INIT_SIZE)
76+
col_times_which_idx = Ref(1)
7777

7878
random_n = Ref(rand(rng))
7979

@@ -85,9 +85,9 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
8585
cache_mc,
8686
weights_mc,
8787
cumsum_weights_mc,
88-
jump_times,
89-
jump_which,
90-
jump_times_which_idx,
88+
col_times,
89+
col_which,
90+
col_times_which_idx,
9191
)
9292

9393
if jump_callback isa DiscreteLindbladJumpCallback
@@ -129,9 +129,9 @@ function _lindblad_jump_affect!(
129129
cache_mc,
130130
weights_mc,
131131
cumsum_weights_mc,
132-
jump_times,
133-
jump_which,
134-
jump_times_which_idx,
132+
col_times,
133+
col_which,
134+
col_times_which_idx,
135135
)
136136
ψ = integrator.u
137137

@@ -147,13 +147,13 @@ function _lindblad_jump_affect!(
147147

148148
random_n[] = rand(traj_rng)
149149

150-
idx = jump_times_which_idx[]
151-
@inbounds jump_times[idx] = integrator.t
152-
@inbounds jump_which[idx] = collapse_idx
153-
jump_times_which_idx[] += 1
154-
if jump_times_which_idx[] > length(jump_times)
155-
resize!(jump_times, length(jump_times) + JUMP_TIMES_WHICH_INIT_SIZE)
156-
resize!(jump_which, length(jump_which) + JUMP_TIMES_WHICH_INIT_SIZE)
150+
idx = col_times_which_idx[]
151+
@inbounds col_times[idx] = integrator.t
152+
@inbounds col_which[idx] = collapse_idx
153+
col_times_which_idx[] += 1
154+
if col_times_which_idx[] > length(col_times)
155+
resize!(col_times, length(col_times) + COL_TIMES_WHICH_INIT_SIZE)
156+
resize!(col_which, length(col_which) + COL_TIMES_WHICH_INIT_SIZE)
157157
end
158158
u_modified!(integrator, true)
159159
return nothing
@@ -309,9 +309,9 @@ function _similar_affect!(affect::LindbladJump, traj_rng)
309309
cache_mc = similar(affect.cache_mc)
310310
weights_mc = similar(affect.weights_mc)
311311
cumsum_weights_mc = similar(affect.cumsum_weights_mc)
312-
jump_times = similar(affect.jump_times)
313-
jump_which = similar(affect.jump_which)
314-
jump_times_which_idx = Ref(1)
312+
col_times = similar(affect.col_times)
313+
col_which = similar(affect.col_which)
314+
col_times_which_idx = Ref(1)
315315

316316
return LindbladJump(
317317
affect.c_ops,
@@ -321,9 +321,9 @@ function _similar_affect!(affect::LindbladJump, traj_rng)
321321
cache_mc,
322322
weights_mc,
323323
cumsum_weights_mc,
324-
jump_times,
325-
jump_which,
326-
jump_times_which_idx,
324+
col_times,
325+
col_which,
326+
col_times_which_idx,
327327
)
328328
end
329329

src/time_evolution/mcsolve.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ end
1414

1515
# Standard output function
1616
function _mcsolve_output_func(sol, i)
17-
idx = _mc_get_jump_callback(sol).affect!.jump_times_which_idx[]
18-
resize!(_mc_get_jump_callback(sol).affect!.jump_times, idx - 1)
19-
resize!(_mc_get_jump_callback(sol).affect!.jump_which, idx - 1)
17+
idx = _mc_get_jump_callback(sol).affect!.col_times_which_idx[]
18+
resize!(_mc_get_jump_callback(sol).affect!.col_times, idx - 1)
19+
resize!(_mc_get_jump_callback(sol).affect!.col_which, idx - 1)
2020
return (sol, false)
2121
end
2222

@@ -401,21 +401,22 @@ function mcsolve(
401401
_expvals_sol_1 = _mcsolve_get_expvals(_sol_1)
402402

403403
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _mcsolve_get_expvals(sol[:, i]), eachindex(sol))
404-
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
404+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
405405
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
406-
jump_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.jump_times, eachindex(sol))
407-
jump_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.jump_which, eachindex(sol))
406+
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
407+
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
408408

409-
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
409+
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
410410

411411
return TimeEvolutionMCSol(
412412
ntraj,
413413
ens_prob_mc.times,
414414
states,
415415
expvals,
416+
expvals, # This is average_expect
416417
expvals_all,
417-
jump_times,
418-
jump_which,
418+
col_times,
419+
col_which,
419420
sol.converged,
420421
_sol_1.alg,
421422
NamedTuple(_sol_1.prob.kwargs).abstol,

src/time_evolution/smesolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,18 +361,19 @@ function smesolve(
361361

362362
dims = ens_prob.dimensions
363363
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _se_me_sse_get_expvals(sol[:, i]), eachindex(sol))
364-
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
364+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
365365
states = map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims)), eachindex(sol))
366366

367367
expvals =
368368
_se_me_sse_get_expvals(_sol_1) isa Nothing ? nothing :
369-
dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
369+
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
370370

371371
return TimeEvolutionStochasticSol(
372372
ntraj,
373373
ens_prob.times,
374374
states,
375375
expvals,
376+
expvals, # This is average_expect
376377
expvals_all,
377378
sol.converged,
378379
_sol_1.alg,

src/time_evolution/ssesolve.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,18 +362,19 @@ function ssesolve(
362362
normalize_states = Val(false)
363363
dims = ens_prob.dimensions
364364
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _se_me_sse_get_expvals(sol[:, i]), eachindex(sol))
365-
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
365+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
366366
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
367367

368368
expvals =
369369
_se_me_sse_get_expvals(_sol_1) isa Nothing ? nothing :
370-
dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
370+
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
371371

372372
return TimeEvolutionStochasticSol(
373373
ntraj,
374374
ens_prob.times,
375375
states,
376376
expvals,
377+
expvals, # This is average_expect
377378
expvals_all,
378379
sol.converged,
379380
_sol_1.alg,

src/time_evolution/time_evolution.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export liouvillian_floquet, liouvillian_generalized
44

55
const DEFAULT_ODE_SOLVER_OPTIONS = (abstol = 1e-8, reltol = 1e-6, save_everystep = false, save_end = true)
66
const DEFAULT_SDE_SOLVER_OPTIONS = (abstol = 1e-2, reltol = 1e-2, save_everystep = false, save_end = true)
7-
const JUMP_TIMES_WHICH_INIT_SIZE = 200
7+
const COL_TIMES_WHICH_INIT_SIZE = 200
88

99
@doc raw"""
1010
struct TimeEvolutionProblem
@@ -99,9 +99,10 @@ A structure storing the results and some information from solving quantum trajec
9999
- `times::AbstractVector`: The time list of the evolution.
100100
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
101101
- `expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
102-
- `expect_all::Union{AbstractMatrix,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`
103-
- `jump_times::Vector{Vector{Real}}`: The time records of every quantum jump occurred in each trajectory.
104-
- `jump_which::Vector{Vector{Int}}`: The indices of the jump operators in `c_ops` that describe the corresponding quantum jumps occurred in each trajectory.
102+
- `average_expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
103+
- `runs_expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`
104+
- `col_times::Vector{Vector{Real}}`: The time records of every quantum jump occurred in each trajectory.
105+
- `col_which::Vector{Vector{Int}}`: The indices of which collapse operator was responsible for each quantum jump in `col_times`.
105106
- `converged::Bool`: Whether the solution is converged or not.
106107
- `alg`: The algorithm which is used during the solving process.
107108
- `abstol::Real`: The absolute tolerance which is used during the solving process.
@@ -122,9 +123,10 @@ struct TimeEvolutionMCSol{
122123
times::TT
123124
states::TS
124125
expect::TE
125-
expect_all::TEA
126-
jump_times::TJT
127-
jump_which::TJW
126+
average_expect::TE # Currently just a synonym for `expect`
127+
runs_expect::TEA
128+
col_times::TJT
129+
col_which::TJW
128130
converged::Bool
129131
alg::AlgT
130132
abstol::AT
@@ -140,7 +142,7 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol)
140142
if sol.expect isa Nothing
141143
print(io, "num_expect = 0\n")
142144
else
143-
print(io, "num_expect = $(size(sol.expect, 1))\n")
145+
print(io, "num_expect = $(size(sol.average_expect, 1))\n")
144146
end
145147
print(io, "ODE alg.: $(sol.alg)\n")
146148
print(io, "abstol = $(sol.abstol)\n")
@@ -159,7 +161,8 @@ A structure storing the results and some information from solving trajectories o
159161
- `times::AbstractVector`: The time list of the evolution.
160162
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
161163
- `expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
162-
- `expect_all::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`
164+
- `average_expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
165+
- `runs_expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`
163166
- `converged::Bool`: Whether the solution is converged or not.
164167
- `alg`: The algorithm which is used during the solving process.
165168
- `abstol::Real`: The absolute tolerance which is used during the solving process.
@@ -178,7 +181,8 @@ struct TimeEvolutionStochasticSol{
178181
times::TT
179182
states::TS
180183
expect::TE
181-
expect_all::TEA
184+
average_expect::TE # Currently just a synonym for `expect`
185+
runs_expect::TEA
182186
converged::Bool
183187
alg::AlgT
184188
abstol::AT
@@ -194,7 +198,7 @@ function Base.show(io::IO, sol::TimeEvolutionStochasticSol)
194198
if sol.expect isa Nothing
195199
print(io, "num_expect = 0\n")
196200
else
197-
print(io, "num_expect = $(size(sol.expect, 1))\n")
201+
print(io, "num_expect = $(size(sol.average_expect, 1))\n")
198202
end
199203
print(io, "SDE alg.: $(sol.alg)\n")
200204
print(io, "abstol = $(sol.abstol)\n")

test/core-test/time_evolution.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -660,21 +660,21 @@
660660
)
661661

662662
@test sol_mc1.expect sol_mc2.expect atol = 1e-10
663-
@test sol_mc1.expect_all sol_mc2.expect_all atol = 1e-10
664-
@test sol_mc1.jump_times sol_mc2.jump_times atol = 1e-10
665-
@test sol_mc1.jump_which sol_mc2.jump_which atol = 1e-10
663+
@test sol_mc1.runs_expect sol_mc2.runs_expect atol = 1e-10
664+
@test sol_mc1.col_times sol_mc2.col_times atol = 1e-10
665+
@test sol_mc1.col_which sol_mc2.col_which atol = 1e-10
666666

667-
@test sol_mc1.expect_all sol_mc3.expect_all[:, :, 1:500] atol = 1e-10
667+
@test sol_mc1.runs_expect sol_mc3.runs_expect[:, 1:500, :] atol = 1e-10
668668

669669
@test sol_sse1.expect sol_sse2.expect atol = 1e-10
670-
@test sol_sse1.expect_all sol_sse2.expect_all atol = 1e-10
670+
@test sol_sse1.runs_expect sol_sse2.runs_expect atol = 1e-10
671671

672-
@test sol_sse1.expect_all sol_sse3.expect_all[:, :, 1:50] atol = 1e-10
672+
@test sol_sse1.runs_expect sol_sse3.runs_expect[:, 1:50, :] atol = 1e-10
673673

674674
@test sol_sme1.expect sol_sme2.expect atol = 1e-10
675-
@test sol_sme1.expect_all sol_sme2.expect_all atol = 1e-10
675+
@test sol_sme1.runs_expect sol_sme2.runs_expect atol = 1e-10
676676

677-
@test sol_sme1.expect_all sol_sme3.expect_all[:, :, 1:50] atol = 1e-10
677+
@test sol_sme1.runs_expect sol_sme3.runs_expect[:, 1:50, :] atol = 1e-10
678678
end
679679
end
680680

0 commit comments

Comments
 (0)