Skip to content

Commit 0bc7621

Browse files
Add keep_runs_results option for multi-trajectory solvers to align with qutip (#512)
Co-authored-by: Alberto Mercurio <[email protected]>
1 parent c6283fe commit 0bc7621

File tree

14 files changed

+543
-206
lines changed

14 files changed

+543
-206
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
- Improve efficiency of `bloch_redfield_tensor` by avoiding unnecessary conversions. ([#509])
1111
- Support `SciMLOperators v1.4+`. ([#470])
1212
- Fix compatibility with `Makie v0.24+`. ([#513])
13+
- Add `keep_runs_results` option for multi-trajectory solvers to align with `QuTiP`. ([#512])
14+
- Breaking changes for multi-trajectory solutions:
15+
- the original fields `expect` and `states` now store the results depend on keyword argument `keep_runs_results` (decide whether to store all trajectories results or not).
16+
- remove field `average_expect`
17+
- remove field `runs_expect`
18+
- New statistical analysis functions:
19+
- `average_states`
20+
- `average_expect`
21+
- `std_expect`
1322

1423
## [v0.33.0]
1524
Release date: 2025-07-22
@@ -273,4 +282,5 @@ Release date: 2024-11-13
273282
[#506]: https://github.com/qutip/QuantumToolbox.jl/issues/506
274283
[#507]: https://github.com/qutip/QuantumToolbox.jl/issues/507
275284
[#509]: https://github.com/qutip/QuantumToolbox.jl/issues/509
285+
[#512]: https://github.com/qutip/QuantumToolbox.jl/issues/512
276286
[#513]: https://github.com/qutip/QuantumToolbox.jl/issues/513

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2424
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2525
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2626
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
27+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2728
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
2829

2930
[weakdeps]
@@ -65,5 +66,6 @@ SciMLOperators = "1.4"
6566
SparseArrays = "1"
6667
SpecialFunctions = "2"
6768
StaticArraysCore = "1"
69+
Statistics = "1"
6870
StochasticDiffEq = "6"
6971
julia = "1.10"

docs/src/resources/api.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ TimeEvolutionProblem
197197
TimeEvolutionSol
198198
TimeEvolutionMCSol
199199
TimeEvolutionStochasticSol
200+
average_states
201+
average_expect
202+
std_expect
200203
sesolveProblem
201204
mesolveProblem
202205
mcsolveProblem

docs/src/users_guide/time_evolution/solution.md

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,47 @@ Some other solvers can have other output.
9797

9898
## [Multiple trajectories solution](@id doc-TE:Multiple-trajectories-solution)
9999

100-
This part is still under construction, please read the docstrings for the following functions first:
100+
The solutions are different for solvers which compute multiple trajectories, such as the [`TimeEvolutionMCSol`](@ref) (Monte Carlo) or the [`TimeEvolutionStochasticSol`](@ref) (stochastic methods). The storage of expectation values and states depends on the keyword argument `keep_runs_results`, which determines whether the results of all trajectories are stored in the solution.
101101

102-
- [`TimeEvolutionMCSol`](@ref)
103-
- [`TimeEvolutionStochasticSol`](@ref)
102+
When the keyword argument `keep_runs_results` is passed as `Val(false)` to a multi-trajectory solver, the `states` and `expect` fields store only the average results (averaged over all trajectories). The results can be accessed by the following index-order:
103+
104+
```julia
105+
sol.states[time_idx]
106+
sol.expect[e_op,time_idx]
107+
```
108+
109+
For example:
110+
111+
```@example TE-solution
112+
tlist = LinRange(0, 1, 11)
113+
c_ops = (destroy(2),)
114+
e_ops = (num(2),)
115+
116+
sol_mc1 = mcsolve(H, ψ0, tlist, c_ops, e_ops=e_ops, ntraj=25, keep_runs_results=Val(false), progress_bar=Val(false))
117+
118+
size(sol_mc1.expect)
119+
```
120+
121+
If the keyword argument `keep_runs_results = Val(true)`, the results for each trajectory and each time are stored, and the index-order of the elements in fields `states` and `expect` are:
122+
123+
124+
```julia
125+
sol.states[trajectory,time_idx]
126+
sol.expect[e_op,trajectory,time_idx]
127+
```
128+
129+
For example:
130+
131+
```@example TE-solution
132+
sol_mc2 = mcsolve(H, ψ0, tlist, c_ops, e_ops=e_ops, ntraj=25, keep_runs_results=Val(true), progress_bar=Val(false))
133+
134+
size(sol_mc2.expect)
135+
```
136+
137+
We also provide the following functions for statistical analysis of multi-trajectory `sol`utions:
138+
139+
| **Functions** | **Description** |
140+
|:------------|:----------------|
141+
| [`average_states(sol)`](@ref average_states) | Return the trajectory-averaged result states (as density [`Operator`](@ref)) at each time point. |
142+
| [`average_expect(sol)`](@ref average_expect) | Return the trajectory-averaged expectation values at each time point. |
143+
| [`std_expect(sol)`](@ref std_expect) | Return the trajectory-wise standard deviation of the expectation values at each time point. |

src/QuantumToolbox.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using LinearAlgebra
55
import LinearAlgebra: BlasInt, BlasFloat, checksquare
66
import LinearAlgebra.LAPACK: hseqr!
77
using SparseArrays
8+
import Statistics: mean, std
89

910
# SciML packages (for QobjEvo, OrdinaryDiffEq, and LinearSolve)
1011
import SciMLBase:

src/time_evolution/mcsolve.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
function _normalize_state!(u, dims, normalize_states)
2424
getVal(normalize_states) && normalize!(u)
25-
return QuantumObject(u, type = Ket(), dims = dims)
25+
return QuantumObject(u, Ket(), dims)
2626
end
2727

2828
function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops)
@@ -278,6 +278,7 @@ end
278278
progress_bar::Union{Val,Bool} = Val(true),
279279
prob_func::Union{Function, Nothing} = nothing,
280280
output_func::Union{Tuple,Nothing} = nothing,
281+
keep_runs_results::Union{Val,Bool} = Val(false),
281282
normalize_states::Union{Val,Bool} = Val(true),
282283
kwargs...,
283284
)
@@ -332,6 +333,7 @@ If the environmental measurements register a quantum jump, the wave function und
332333
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
333334
- `prob_func`: Function to use for generating the ODEProblem.
334335
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
336+
- `keep_runs_results`: Whether to save the results of each trajectory. Default to `Val(false)`.
335337
- `normalize_states`: Whether to normalize the states. Default to `Val(true)`.
336338
- `kwargs`: The keyword arguments for the ODEProblem.
337339
@@ -363,6 +365,7 @@ function mcsolve(
363365
progress_bar::Union{Val,Bool} = Val(true),
364366
prob_func::Union{Function,Nothing} = nothing,
365367
output_func::Union{Tuple,Nothing} = nothing,
368+
keep_runs_results::Union{Val,Bool} = Val(false),
366369
normalize_states::Union{Val,Bool} = Val(true),
367370
kwargs...,
368371
) where {TJC<:LindbladJumpCallbackType}
@@ -384,14 +387,15 @@ function mcsolve(
384387
kwargs...,
385388
)
386389

387-
return mcsolve(ens_prob_mc, alg, ntraj, ensemblealg, normalize_states)
390+
return mcsolve(ens_prob_mc, alg, ntraj, ensemblealg, makeVal(keep_runs_results), normalize_states)
388391
end
389392

390393
function mcsolve(
391394
ens_prob_mc::TimeEvolutionProblem,
392395
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
393396
ntraj::Int = 500,
394397
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
398+
keep_runs_results = Val(false),
395399
normalize_states = Val(true),
396400
)
397401
sol = _ensemble_dispatch_solve(ens_prob_mc, alg, ensemblealg, ntraj)
@@ -403,25 +407,24 @@ function mcsolve(
403407
_expvals_all =
404408
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMCSolve), eachindex(sol))
405409
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
406-
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
410+
411+
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
412+
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
413+
407414
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
408415
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
409416

410-
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
411-
412417
return TimeEvolutionMCSol(
413418
ntraj,
414419
ens_prob_mc.times,
415420
_sol_1.t,
416-
states,
417-
expvals,
418-
expvals, # This is average_expect
419-
expvals_all,
421+
_store_multitraj_states(states_all, keep_runs_results),
422+
_store_multitraj_expect(expvals_all, keep_runs_results),
420423
col_times,
421424
col_which,
422425
sol.converged,
423426
_sol_1.alg,
424-
NamedTuple(_sol_1.prob.kwargs).abstol,
425-
NamedTuple(_sol_1.prob.kwargs).reltol,
427+
_sol_1.prob.kwargs[:abstol],
428+
_sol_1.prob.kwargs[:reltol],
426429
)
427430
end

src/time_evolution/mesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit
211211
_get_expvals(sol, SaveFuncMESolve),
212212
sol.retcode,
213213
sol.alg,
214-
NamedTuple(sol.prob.kwargs).abstol,
215-
NamedTuple(sol.prob.kwargs).reltol,
214+
sol.prob.kwargs[:abstol],
215+
sol.prob.kwargs[:reltol],
216216
)
217217
end

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit
161161
_get_expvals(sol, SaveFuncSESolve),
162162
sol.retcode,
163163
sol.alg,
164-
NamedTuple(sol.prob.kwargs).abstol,
165-
NamedTuple(sol.prob.kwargs).reltol,
164+
sol.prob.kwargs[:abstol],
165+
sol.prob.kwargs[:reltol],
166166
)
167167
end

src/time_evolution/smesolve.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ end
293293
prob_func::Union{Function, Nothing} = nothing,
294294
output_func::Union{Tuple,Nothing} = nothing,
295295
progress_bar::Union{Val,Bool} = Val(true),
296+
keep_runs_results::Union{Val,Bool} = Val(false),
296297
store_measurement::Union{Val,Bool} = Val(false),
297298
kwargs...,
298299
)
@@ -333,6 +334,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
333334
- `prob_func`: Function to use for generating the SDEProblem.
334335
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
335336
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
337+
- `keep_runs_results`: Whether to save the results of each trajectory. Default to `Val(false)`.
336338
- `store_measurement`: Whether to store the measurement expectation values. Default is `Val(false)`.
337339
- `kwargs`: The keyword arguments for the ODEProblem.
338340
@@ -365,6 +367,7 @@ function smesolve(
365367
prob_func::Union{Function,Nothing} = nothing,
366368
output_func::Union{Tuple,Nothing} = nothing,
367369
progress_bar::Union{Val,Bool} = Val(true),
370+
keep_runs_results::Union{Val,Bool} = Val(false),
368371
store_measurement::Union{Val,Bool} = Val(false),
369372
kwargs...,
370373
) where {StateOpType<:Union{Ket,Operator,OperatorKet}}
@@ -392,14 +395,15 @@ function smesolve(
392395
alg = sc_ops_isa_Qobj ? SRIW1() : SRA2()
393396
end
394397

395-
return smesolve(ensemble_prob, alg, ntraj, ensemblealg)
398+
return smesolve(ensemble_prob, alg, ntraj, ensemblealg, makeVal(keep_runs_results))
396399
end
397400

398401
function smesolve(
399402
ens_prob::TimeEvolutionProblem,
400403
alg::StochasticDiffEqAlgorithm = SRA2(),
401404
ntraj::Int = 500,
402405
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
406+
keep_runs_results = Val(false),
403407
)
404408
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
405409

@@ -412,24 +416,22 @@ function smesolve(
412416
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMESolve), eachindex(sol))
413417
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
414418

415-
states = map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), ens_prob.kwargs.isoperket), eachindex(sol))
419+
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
420+
states_all = stack(
421+
map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), ens_prob.kwargs.isoperket), eachindex(sol)),
422+
dims = 1,
423+
)
416424

417425
_m_expvals =
418426
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol))
419-
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
420-
421-
expvals =
422-
_get_expvals(_sol_1, SaveFuncMESolve) isa Nothing ? nothing :
423-
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
427+
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) # Stack on dimension 2 to align with QuTiP
424428

425429
return TimeEvolutionStochasticSol(
426430
ntraj,
427431
ens_prob.times,
428432
_sol_1.t,
429-
states,
430-
expvals,
431-
expvals, # This is average_expect
432-
expvals_all,
433+
_store_multitraj_states(states_all, keep_runs_results),
434+
_store_multitraj_expect(expvals_all, keep_runs_results),
433435
m_expvals, # Measurement expectation values
434436
sol.converged,
435437
_sol_1.alg,

src/time_evolution/ssesolve.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ end
285285
prob_func::Union{Function, Nothing} = nothing,
286286
output_func::Union{Tuple,Nothing} = nothing,
287287
progress_bar::Union{Val,Bool} = Val(true),
288+
keep_runs_results::Union{Val,Bool} = Val(false),
288289
store_measurement::Union{Val,Bool} = Val(false),
289290
kwargs...,
290291
)
@@ -328,6 +329,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is th
328329
- `prob_func`: Function to use for generating the SDEProblem.
329330
- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
330331
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
332+
- `keep_runs_results`: Whether to save the results of each trajectory. Default to `Val(false)`.
331333
- `store_measurement`: Whether to store the measurement results. Default is `Val(false)`.
332334
- `kwargs`: The keyword arguments for the ODEProblem.
333335
@@ -360,6 +362,7 @@ function ssesolve(
360362
prob_func::Union{Function,Nothing} = nothing,
361363
output_func::Union{Tuple,Nothing} = nothing,
362364
progress_bar::Union{Val,Bool} = Val(true),
365+
keep_runs_results::Union{Val,Bool} = Val(false),
363366
store_measurement::Union{Val,Bool} = Val(false),
364367
kwargs...,
365368
)
@@ -386,14 +389,15 @@ function ssesolve(
386389
alg = sc_ops_isa_Qobj ? SRIW1() : SRA2()
387390
end
388391

389-
return ssesolve(ens_prob, alg, ntraj, ensemblealg)
392+
return ssesolve(ens_prob, alg, ntraj, ensemblealg, makeVal(keep_runs_results))
390393
end
391394

392395
function ssesolve(
393396
ens_prob::TimeEvolutionProblem,
394397
alg::StochasticDiffEqAlgorithm = SRA2(),
395398
ntraj::Int = 500,
396399
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
400+
keep_runs_results = Val(false),
397401
)
398402
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
399403

@@ -406,24 +410,20 @@ function ssesolve(
406410
_expvals_all =
407411
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol))
408412
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
409-
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
413+
414+
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
415+
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
410416

411417
_m_expvals =
412418
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol))
413419
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
414420

415-
expvals =
416-
_get_expvals(_sol_1, SaveFuncSSESolve) isa Nothing ? nothing :
417-
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
418-
419421
return TimeEvolutionStochasticSol(
420422
ntraj,
421423
ens_prob.times,
422424
_sol_1.t,
423-
states,
424-
expvals,
425-
expvals, # This is average_expect
426-
expvals_all,
425+
_store_multitraj_states(states_all, keep_runs_results),
426+
_store_multitraj_expect(expvals_all, keep_runs_results),
427427
m_expvals, # Measurement expectation values
428428
sol.converged,
429429
_sol_1.alg,

0 commit comments

Comments
 (0)