Skip to content

Commit 45495be

Browse files
Make sesolve working
1 parent 73fe861 commit 45495be

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

src/time_evolution/sesolve.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,24 @@ function _save_func_sesolve(integrator)
1616
return u_modified!(integrator, false)
1717
end
1818

19-
function _generate_sesolve_kwargs_with_callback(t_l, kwargs)
20-
cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false))
19+
function _generate_sesolve_kwargs_with_callback(tlist, kwargs)
20+
cb1 = PresetTimeCallback(tlist, _save_func_sesolve, save_positions = (false, false))
2121
kwargs2 =
2222
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) :
2323
merge(kwargs, (callback = cb1,))
2424

2525
return kwargs2
2626
end
2727

28-
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, t_l, kwargs)
29-
return _generate_sesolve_kwargs_with_callback(t_l, kwargs)
28+
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, tlist, kwargs)
29+
return _generate_sesolve_kwargs_with_callback(tlist, kwargs)
3030
end
3131

32-
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, t_l, kwargs)
32+
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, tlist, kwargs)
3333
if e_ops isa Nothing
3434
return kwargs
3535
end
36-
return _generate_sesolve_kwargs_with_callback(t_l, kwargs)
36+
return _generate_sesolve_kwargs_with_callback(tlist, kwargs)
3737
end
3838

3939
@doc raw"""
@@ -88,23 +88,23 @@ function sesolveProblem(
8888
haskey(kwargs, :save_idxs) &&
8989
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
9090

91-
ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
92-
93-
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
91+
tlist = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
9492

9593
H_evo = QobjEvo(H, -1im) # pre-multiply by -i
9694
isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
9795
check_dims(H_evo, ψ0)
96+
97+
ψ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
9898
U = H_evo.data
9999

100-
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))
100+
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))
101101

102102
if e_ops isa Nothing
103-
expvals = Array{ComplexF64}(undef, 0, length(t_l))
103+
expvals = Array{ComplexF64}(undef, 0, length(tlist))
104104
e_ops_data = ()
105105
is_empty_e_ops = true
106106
else
107-
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
107+
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
108108
e_ops_data = get_data.(e_ops)
109109
is_empty_e_ops = isempty(e_ops)
110110
end
@@ -114,18 +114,17 @@ function sesolveProblem(
114114
expvals = expvals,
115115
progr = progr,
116116
Hdims = H_evo.dims,
117-
times = t_l,
118117
is_empty_e_ops = is_empty_e_ops,
119118
params...,
120119
)
121120

122-
saveat = is_empty_e_ops ? t_l : [t_l[end]]
121+
saveat = is_empty_e_ops ? tlist : [tlist[end]]
123122
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
124123
kwargs2 = merge(default_values, kwargs)
125-
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)
124+
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2)
126125

127-
tspan = (t_l[1], t_l[end])
128-
return ODEProblem{true,FullSpecialize}(U, ϕ0, tspan, p; kwargs3...)
126+
tspan = (tlist[1], tlist[end])
127+
return ODEProblem{true,FullSpecialize}(U, ψ0, tspan, p; kwargs3...)
129128
end
130129

131130
@doc raw"""
@@ -181,16 +180,16 @@ function sesolve(
181180
) where {DT1,DT2}
182181
prob = sesolveProblem(H, ψ0, tlist; e_ops = e_ops, params = params, progress_bar = progress_bar, kwargs...)
183182

184-
return sesolve(prob, alg)
183+
return sesolve(prob, tlist, alg)
185184
end
186185

187-
function sesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
186+
function sesolve(prob::ODEProblem, tlist::AbstractVector, alg::OrdinaryDiffEqAlgorithm = Tsit5())
188187
sol = solve(prob, alg)
189188

190189
ψt = map-> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)
191190

192191
return TimeEvolutionSol(
193-
sol.prob.p.times,
192+
tlist,
194193
ψt,
195194
sol.prob.p.expvals,
196195
sol.retcode,

src/time_evolution/time_evolution.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,22 @@ A structure storing the results and some information from solving time evolution
2121
- `abstol::Real`: The absolute tolerance which is used during the solving process.
2222
- `reltol::Real`: The relative tolerance which is used during the solving process.
2323
"""
24-
struct TimeEvolutionSol{TT<:Vector{<:Real},TS<:AbstractVector,TE<:Matrix{ComplexF64}}
24+
struct TimeEvolutionSol{
25+
TT<:AbstractVector{<:Real},
26+
TS<:AbstractVector,
27+
TE<:Matrix,
28+
RETT<:Enum,
29+
AlgT<:OrdinaryDiffEqAlgorithm,
30+
AT<:Real,
31+
RT<:Real,
32+
}
2533
times::TT
2634
states::TS
2735
expect::TE
28-
retcode::Enum
29-
alg::OrdinaryDiffEqAlgorithm
30-
abstol::Real
31-
reltol::Real
36+
retcode::RETT
37+
alg::AlgT
38+
abstol::AT
39+
reltol::RT
3240
end
3341

3442
function Base.show(io::IO, sol::TimeEvolutionSol)

0 commit comments

Comments
 (0)