Skip to content

Commit b71a0e3

Browse files
committed
align statistical analysis methods for multi-trajectory with qutip
1 parent 5e41ed3 commit b71a0e3

File tree

5 files changed

+59
-39
lines changed

5 files changed

+59
-39
lines changed

docs/src/resources/api.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ TimeEvolutionProblem
192192
TimeEvolutionSol
193193
TimeEvolutionMCSol
194194
TimeEvolutionStochasticSol
195+
average_states
196+
average_expect
197+
std_expect
195198
sesolveProblem
196199
mesolveProblem
197200
mcsolveProblem

src/time_evolution/mcsolve.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -407,14 +407,10 @@ function mcsolve(
407407
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
408408
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
409409

410-
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
411-
412410
return TimeEvolutionMCSol(
413411
ntraj,
414412
ens_prob_mc.times,
415413
states,
416-
expvals,
417-
expvals, # This is average_expect
418414
expvals_all,
419415
col_times,
420416
col_which,

src/time_evolution/smesolve.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,16 +418,10 @@ function smesolve(
418418
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol))
419419
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
420420

421-
expvals =
422-
_get_expvals(_sol_1, SaveFuncMESolve) isa Nothing ? nothing :
423-
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
424-
425421
return TimeEvolutionStochasticSol(
426422
ntraj,
427423
ens_prob.times,
428424
states,
429-
expvals,
430-
expvals, # This is average_expect
431425
expvals_all,
432426
m_expvals, # Measurement expectation values
433427
sol.converged,

src/time_evolution/ssesolve.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -412,16 +412,10 @@ function ssesolve(
412412
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol))
413413
m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2)
414414

415-
expvals =
416-
_get_expvals(_sol_1, SaveFuncSSESolve) isa Nothing ? nothing :
417-
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
418-
419415
return TimeEvolutionStochasticSol(
420416
ntraj,
421417
ens_prob.times,
422418
states,
423-
expvals,
424-
expvals, # This is average_expect
425419
expvals_all,
426420
m_expvals, # Measurement expectation values
427421
sol.converged,

src/time_evolution/time_evolution.jl

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
export TimeEvolutionSol, TimeEvolutionMCSol, TimeEvolutionStochasticSol
1+
export TimeEvolutionSol
2+
export TimeEvolutionMultiTrajSol, TimeEvolutionMCSol, TimeEvolutionStochasticSol
3+
export average_states, average_expect, std_expect
24

35
export liouvillian_floquet, liouvillian_generalized
46

57
const DEFAULT_ODE_SOLVER_OPTIONS = (abstol = 1e-8, reltol = 1e-6, save_everystep = false, save_end = true)
68
const DEFAULT_SDE_SOLVER_OPTIONS = (abstol = 1e-3, reltol = 2e-3, save_everystep = false, save_end = true)
79
const COL_TIMES_WHICH_INIT_SIZE = 200
810

11+
abstract type TimeEvolutionMultiTrajSol{Texpect} end
12+
913
@doc raw"""
1014
struct TimeEvolutionProblem
1115
@@ -89,42 +93,45 @@ function Base.show(io::IO, sol::TimeEvolutionSol)
8993
end
9094

9195
@doc raw"""
92-
struct TimeEvolutionMCSol
96+
struct TimeEvolutionMCSol <: TimeEvolutionMultiTrajSol
9397
9498
A structure storing the results and some information from solving quantum trajectories of the Monte Carlo wave function time evolution.
9599
96100
# Fields (Attributes)
97101
98102
- `ntraj::Int`: Number of trajectories
99103
- `times::AbstractVector`: The time list of the evolution.
100-
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
101-
- `expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
102-
- `average_expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
103-
- `runs_expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`
104+
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory and each time point in `times`.
105+
- `expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`.
104106
- `col_times::Vector{Vector{Real}}`: The time records of every quantum jump occurred in each trajectory.
105107
- `col_which::Vector{Vector{Int}}`: The indices of which collapse operator was responsible for each quantum jump in `col_times`.
106108
- `converged::Bool`: Whether the solution is converged or not.
107109
- `alg`: The algorithm which is used during the solving process.
108110
- `abstol::Real`: The absolute tolerance which is used during the solving process.
109111
- `reltol::Real`: The relative tolerance which is used during the solving process.
112+
113+
# Methods
114+
115+
We also provide the following functions for analyzing statistics from multi-trajectory solutions.
116+
117+
- [`average_states`](@ref)
118+
- [`average_expect`](@ref)
119+
- [`std_expect`](@ref)
110120
"""
111121
struct TimeEvolutionMCSol{
112122
TT<:AbstractVector{<:Real},
113123
TS<:AbstractVector,
114-
TE<:Union{AbstractMatrix,Nothing},
115-
TEA<:Union{AbstractArray,Nothing},
124+
TE<:Union{AbstractArray,Nothing},
116125
TJT<:Vector{<:Vector{<:Real}},
117126
TJW<:Vector{<:Vector{<:Integer}},
118127
AlgT<:OrdinaryDiffEqAlgorithm,
119128
AT<:Real,
120129
RT<:Real,
121-
}
130+
} <: TimeEvolutionMultiTrajSol{TE}
122131
ntraj::Int
123132
times::TT
124133
states::TS
125134
expect::TE
126-
average_expect::TE # Currently just a synonym for `expect`
127-
runs_expect::TEA
128135
col_times::TJT
129136
col_which::TJW
130137
converged::Bool
@@ -142,7 +149,7 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol)
142149
if sol.expect isa Nothing
143150
print(io, "num_expect = 0\n")
144151
else
145-
print(io, "num_expect = $(size(sol.average_expect, 1))\n")
152+
print(io, "num_expect = $(size(sol.expect, 1))\n")
146153
end
147154
print(io, "ODE alg.: $(sol.alg)\n")
148155
print(io, "abstol = $(sol.abstol)\n")
@@ -151,39 +158,42 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol)
151158
end
152159

153160
@doc raw"""
154-
struct TimeEvolutionStochasticSol
161+
struct TimeEvolutionStochasticSol <: TimeEvolutionMultiTrajSol
155162
156163
A structure storing the results and some information from solving trajectories of the Stochastic time evolution.
157164
158165
# Fields (Attributes)
159166
160167
- `ntraj::Int`: Number of trajectories
161168
- `times::AbstractVector`: The time list of the evolution.
162-
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
163-
- `expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
164-
- `average_expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
165-
- `runs_expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`
169+
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory and each time point in `times`.
170+
- `expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`.
166171
- `converged::Bool`: Whether the solution is converged or not.
167172
- `alg`: The algorithm which is used during the solving process.
168173
- `abstol::Real`: The absolute tolerance which is used during the solving process.
169174
- `reltol::Real`: The relative tolerance which is used during the solving process.
175+
176+
# Methods
177+
178+
We also provide the following functions for analyzing statistics from multi-trajectory solutions.
179+
180+
- [`average_states`](@ref)
181+
- [`average_expect`](@ref)
182+
- [`std_expect`](@ref)
170183
"""
171184
struct TimeEvolutionStochasticSol{
172185
TT<:AbstractVector{<:Real},
173186
TS<:AbstractVector,
174-
TE<:Union{AbstractMatrix,Nothing},
175-
TEA<:Union{AbstractArray,Nothing},
187+
TE<:Union{AbstractArray,Nothing},
176188
TEM<:Union{AbstractArray,Nothing},
177189
AlgT<:StochasticDiffEqAlgorithm,
178190
AT<:Real,
179191
RT<:Real,
180-
}
192+
} <: TimeEvolutionMultiTrajSol{TE}
181193
ntraj::Int
182194
times::TT
183195
states::TS
184196
expect::TE
185-
average_expect::TE # Currently just a synonym for `expect`
186-
runs_expect::TEA
187197
measurement::TEM
188198
converged::Bool
189199
alg::AlgT
@@ -200,14 +210,37 @@ function Base.show(io::IO, sol::TimeEvolutionStochasticSol)
200210
if sol.expect isa Nothing
201211
print(io, "num_expect = 0\n")
202212
else
203-
print(io, "num_expect = $(size(sol.average_expect, 1))\n")
213+
print(io, "num_expect = $(size(sol.expect, 1))\n")
204214
end
205215
print(io, "SDE alg.: $(sol.alg)\n")
206216
print(io, "abstol = $(sol.abstol)\n")
207217
print(io, "reltol = $(sol.reltol)\n")
208218
return nothing
209219
end
210220

221+
@doc raw"""
222+
average_states(sol::TimeEvolutionMultiTrajSol)
223+
224+
Return the trajectory-averaged result states at each time point.
225+
"""
226+
average_states(sol::TimeEvolutionMultiTrajSol) = mean(sol.states)
227+
228+
@doc raw"""
229+
average_expect(sol::TimeEvolutionMultiTrajSol)
230+
231+
Return the trajectory-averaged expectation values at each time point.
232+
"""
233+
average_expect(sol::TimeEvolutionMultiTrajSol{TE}) where {TE<:AbstractArray} = dropdims(mean(sol.runs_expect, dims = 2), dims = 2)
234+
average_expect(sol::TimeEvolutionMultiTrajSol{Nothing}) = nothing
235+
236+
@doc raw"""
237+
std_expect(sol::TimeEvolutionMultiTrajSol)
238+
239+
Return the trajectory-wise standard deviation of the expectation values at each time point.
240+
"""
241+
std_expect(sol::TimeEvolutionMultiTrajSol{TE}) where {TE<:AbstractArray} = dropdims(std(sol.runs_expect, dims = 2), dims = 2)
242+
std_expect(sol::TimeEvolutionMultiTrajSol{Nothing}) = nothing
243+
211244
#######################################
212245
#=
213246
Callbacks for Monte Carlo quantum trajectories

0 commit comments

Comments
 (0)