Skip to content

Commit 905e658

Browse files
authored
fix incorrect times in time evolution solutions (#244)
1 parent 4df91fb commit 905e658

File tree

8 files changed

+53
-38
lines changed

8 files changed

+53
-38
lines changed

src/time_evolution/lr_mesolve.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,19 @@ function _periodicsave_func(integrator)
147147
return u_modified!(integrator, false)
148148
end
149149

150-
_save_control_lr_mesolve(u, t, integrator) = t in integrator.p.t_l
150+
_save_control_lr_mesolve(u, t, integrator) = t in integrator.p.times
151151

152152
function _save_affect_lr_mesolve!(integrator)
153153
ip = integrator.p
154154
N, M = ip.N, ip.M
155-
idx = select(integrator.t, ip.t_l)
155+
idx = select(integrator.t, ip.times)
156156

157157
@views z = reshape(integrator.u[1:N*M], N, M)
158158
@views B = reshape(integrator.u[N*M+1:end], M, M)
159159
_calculate_expectation!(ip, z, B, idx)
160160

161161
if integrator.p.opt.progress
162-
print("\rProgress: $(round(Int, 100*idx/length(ip.t_l)))%")
162+
print("\rProgress: $(round(Int, 100*idx/length(ip.times)))%")
163163
flush(stdout)
164164
end
165165
return u_modified!(integrator, false)
@@ -365,7 +365,7 @@ end
365365
#=======================================================#
366366

367367
@doc raw"""
368-
lr_mesolveProblem(H, z, B, t_l, c_ops; e_ops=(), f_ops=(), opt=LRMesolveOptions(), kwargs...) where T
368+
lr_mesolveProblem(H, z, B, tlist, c_ops; e_ops=(), f_ops=(), opt=LRMesolveOptions(), kwargs...) where T
369369
Formulates the ODEproblem for the low-rank time evolution of the system. The function is called by lr_mesolve.
370370
371371
Parameters
@@ -376,7 +376,7 @@ end
376376
The initial z matrix.
377377
B : AbstractMatrix{T}
378378
The initial B matrix.
379-
t_l : AbstractVector{T}
379+
tlist : AbstractVector{T}
380380
The time steps at which the expectation values and function values are calculated.
381381
c_ops : AbstractVector{QuantumObject}
382382
The jump operators of the system.
@@ -393,7 +393,7 @@ function lr_mesolveProblem(
393393
H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
394394
z::AbstractArray{T2,2},
395395
B::AbstractArray{T2,2},
396-
t_l::AbstractVector,
396+
tlist::AbstractVector,
397397
c_ops::AbstractVector = [];
398398
e_ops::Tuple = (),
399399
f_ops::Tuple = (),
@@ -407,6 +407,8 @@ function lr_mesolveProblem(
407407
c_ops = get_data.(c_ops)
408408
e_ops = get_data.(e_ops)
409409

410+
t_l = convert(Vector{_FType(H)}, tlist)
411+
410412
# Initialization of Arrays
411413
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
412414
funvals = Array{ComplexF64}(undef, length(f_ops), length(t_l))
@@ -421,7 +423,7 @@ function lr_mesolveProblem(
421423
e_ops = e_ops,
422424
f_ops = f_ops,
423425
opt = opt,
424-
t_l = t_l,
426+
times = t_l,
425427
expvals = expvals,
426428
funvals = funvals,
427429
Ml = Ml,
@@ -489,14 +491,14 @@ function lr_mesolve(
489491
H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
490492
z::AbstractArray{T2,2},
491493
B::AbstractArray{T2,2},
492-
t_l::AbstractVector,
494+
tlist::AbstractVector,
493495
c_ops::AbstractVector = [];
494496
e_ops::Tuple = (),
495497
f_ops::Tuple = (),
496498
opt::LRMesolveOptions{AlgType} = LRMesolveOptions(),
497499
kwargs...,
498500
) where {T1,T2,AlgType<:OrdinaryDiffEqAlgorithm}
499-
prob = lr_mesolveProblem(H, z, B, t_l, c_ops; e_ops = e_ops, f_ops = f_ops, opt = opt, kwargs...)
501+
prob = lr_mesolveProblem(H, z, B, tlist, c_ops; e_ops = e_ops, f_ops = f_ops, opt = opt, kwargs...)
500502
return lr_mesolve(prob; kwargs...)
501503
end
502504

@@ -520,7 +522,7 @@ get_B(u::AbstractArray{T}, N::Integer, M::Integer) where {T} = reshape(view(u, (
520522
Additional keyword arguments for the ODEProblem.
521523
"""
522524
function lr_mesolve(prob::ODEProblem; kwargs...)
523-
sol = solve(prob, prob.p.opt.alg, tstops = prob.p.t_l)
525+
sol = solve(prob, prob.p.opt.alg, tstops = prob.p.times)
524526
prob.p.opt.progress && print("\n")
525527

526528
N = prob.p.N
@@ -535,5 +537,5 @@ function lr_mesolve(prob::ODEProblem; kwargs...)
535537
zt = get_z(sol.u, N, Ml)
536538
end
537539

538-
return LRTimeEvolutionSol(sol.t, zt, Bt, prob.p.expvals, prob.p.funvals, prob.p.Ml)
540+
return LRTimeEvolutionSol(sol.prob.p.times, zt, Bt, prob.p.expvals, prob.p.funvals, prob.p.Ml)
539541
end

src/time_evolution/mcsolve.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,12 @@ function _mcsolve_output_func(sol, i)
8686
return (sol, false)
8787
end
8888

89-
function _mcsolve_generate_statistics(sol, i, times, states, expvals_all, jump_times, jump_which)
89+
function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which)
9090
sol_i = sol[:, i]
9191
!isempty(sol_i.prob.kwargs[:saveat]) ?
9292
states[i] = [QuantumObject(normalize!(sol_i.u[i]), dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing
9393

9494
copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals)
95-
times[i] = sol_i.t
9695
jump_times[i] = sol_i.prob.p.jump_times
9796
return jump_which[i] = sol_i.prob.p.jump_which
9897
end
@@ -522,22 +521,18 @@ function mcsolve(
522521
_sol_1 = sol[:, 1]
523522

524523
expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
525-
times = Vector{Vector{Float64}}(undef, length(sol))
526524
states =
527525
isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) :
528526
Vector{Vector{QuantumObject}}(undef, length(sol))
529527
jump_times = Vector{Vector{Float64}}(undef, length(sol))
530528
jump_which = Vector{Vector{Int16}}(undef, length(sol))
531529

532-
foreach(
533-
i -> _mcsolve_generate_statistics(sol, i, times, states, expvals_all, jump_times, jump_which),
534-
eachindex(sol),
535-
)
530+
foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
536531
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)
537532

538533
return TimeEvolutionMCSol(
539534
ntraj,
540-
times,
535+
_sol_1.prob.p.times,
541536
states,
542537
expvals,
543538
expvals_all,

src/time_evolution/mesolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ function mesolveProblem(
150150
e_ops = e_ops2,
151151
expvals = expvals,
152152
H_t = H_t,
153+
times = t_l,
153154
is_empty_e_ops = is_empty_e_ops,
154155
params...,
155156
)
@@ -253,7 +254,7 @@ function mesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
253254
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator, dims = sol.prob.p.Hdims), sol.u)
254255

255256
return TimeEvolutionSol(
256-
sol.t,
257+
sol.prob.p.times,
257258
ρt,
258259
sol.prob.p.expvals,
259260
sol.retcode,

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ function sesolveProblem(
127127
progr = progr,
128128
Hdims = H.dims,
129129
H_t = H_t,
130+
times = t_l,
130131
is_empty_e_ops = is_empty_e_ops,
131132
params...,
132133
)
@@ -215,7 +216,7 @@ function sesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
215216
ψt = map-> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)
216217

217218
return TimeEvolutionSol(
218-
sol.t,
219+
sol.prob.p.times,
219220
ψt,
220221
sol.prob.p.expvals,
221222
sol.retcode,

src/time_evolution/ssesolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ function ssesolveProblem(
180180
progr = progr,
181181
Hdims = H.dims,
182182
H_t = H_t,
183+
times = t_l,
183184
is_empty_e_ops = is_empty_e_ops,
184185
params...,
185186
)
@@ -404,7 +405,7 @@ function ssesolve(
404405

405406
return TimeEvolutionSSESol(
406407
ntraj,
407-
_sol_1.t,
408+
_sol_1.prob.p.times,
408409
states,
409410
expvals,
410411
expvals_all,

src/time_evolution/time_evolution.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ A structure storing the results and some information from solving quantum trajec
5151
# Fields (Attributes)
5252
5353
- `ntraj::Int`: Number of trajectories
54-
- `times::AbstractVector`: The time list of the evolution in each trajectory.
54+
- `times::AbstractVector`: The time list of the evolution.
5555
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
5656
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
5757
- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times`
@@ -63,7 +63,7 @@ A structure storing the results and some information from solving quantum trajec
6363
- `reltol::Real`: The relative tolerance which is used during the solving process.
6464
"""
6565
struct TimeEvolutionMCSol{
66-
TT<:Vector{<:Vector{<:Real}},
66+
TT<:Vector{<:Real},
6767
TS<:AbstractVector,
6868
TE<:Matrix{ComplexF64},
6969
TEA<:Array{ComplexF64,3},
@@ -97,19 +97,22 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol)
9797
end
9898

9999
@doc raw"""
100-
struct TimeEvolutionSSESol
101-
A structure storing the results and some information from solving trajectories of the Stochastic Shrodinger equation time evolution.
102-
# Fields (Attributes)
103-
- `ntraj::Int`: Number of trajectories
104-
- `times::AbstractVector`: The time list of the evolution in each trajectory.
105-
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
106-
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
107-
- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times`
108-
- `converged::Bool`: Whether the solution is converged or not.
109-
- `alg`: The algorithm which is used during the solving process.
110-
- `abstol::Real`: The absolute tolerance which is used during the solving process.
111-
- `reltol::Real`: The relative tolerance which is used during the solving process.
112-
"""
100+
struct TimeEvolutionSSESol
101+
102+
A structure storing the results and some information from solving trajectories of the Stochastic Shrodinger equation time evolution.
103+
104+
# Fields (Attributes)
105+
106+
- `ntraj::Int`: Number of trajectories
107+
- `times::AbstractVector`: The time list of the evolution.
108+
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
109+
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
110+
- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times`
111+
- `converged::Bool`: Whether the solution is converged or not.
112+
- `alg`: The algorithm which is used during the solving process.
113+
- `abstol::Real`: The absolute tolerance which is used during the solving process.
114+
- `reltol::Real`: The relative tolerance which is used during the solving process.
115+
"""
113116
struct TimeEvolutionSSESol{
114117
TT<:Vector{<:Real},
115118
TS<:AbstractVector,

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ function dfd_mesolve(
248248
)
249249

250250
return TimeEvolutionSol(
251-
sol.t,
251+
sol.prob.p.times,
252252
ρt,
253253
sol.prob.p.expvals,
254254
sol.retcode,

test/core-test/time_evolution.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
1717
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
1818
@test sum(abs.(sol.expect[1, :] .- sin.(η * t_l) .^ 2)) / length(t_l) < 0.1
19+
@test length(sol.times) == length(t_l)
1920
@test length(sol.states) == 1
2021
@test size(sol.expect) == (length(e_ops), length(t_l))
22+
@test length(sol2.times) == length(t_l)
2123
@test length(sol2.states) == length(t_l)
2224
@test size(sol2.expect) == (0, length(t_l))
25+
@test length(sol3.times) == length(t_l)
2326
@test length(sol3.states) == length(t_l)
2427
@test size(sol3.expect) == (length(e_ops), length(t_l))
2528
@test sol_string ==
@@ -68,12 +71,21 @@
6871
@test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(t_l) < 0.1
6972
@test sum(abs.(vec(expect_mc_states_mean) .- vec(sol_me.expect))) / length(t_l) < 0.1
7073
@test sum(abs.(sol_sse.expect .- sol_me.expect)) / length(t_l) < 0.1
74+
@test length(sol_me.times) == length(t_l)
7175
@test length(sol_me.states) == 1
7276
@test size(sol_me.expect) == (length(e_ops), length(t_l))
77+
@test length(sol_me2.times) == length(t_l)
7378
@test length(sol_me2.states) == length(t_l)
7479
@test size(sol_me2.expect) == (0, length(t_l))
80+
@test length(sol_me3.times) == length(t_l)
7581
@test length(sol_me3.states) == length(t_l)
7682
@test size(sol_me3.expect) == (length(e_ops), length(t_l))
83+
@test length(sol_mc.times) == length(t_l)
84+
@test size(sol_mc.expect) == (length(e_ops), length(t_l))
85+
@test length(sol_mc_states.times) == length(t_l)
86+
@test size(sol_mc_states.expect) == (0, length(t_l))
87+
@test length(sol_sse.times) == length(t_l)
88+
@test size(sol_sse.expect) == (length(e_ops), length(t_l))
7789
@test sol_me_string ==
7890
"Solution of time evolution\n" *
7991
"(return code: $(sol_me.retcode))\n" *

0 commit comments

Comments
 (0)