Skip to content

Commit 958fe42

Browse files
Working sesolve
1 parent c8c9479 commit 958fe42

File tree

4 files changed

+98
-45
lines changed

4 files changed

+98
-45
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ function mcsolveProblem(
270270
jump_which = jump_which,
271271
jump_times_which_init_size = jump_times_which_init_size,
272272
jump_times_which_idx = Ref(1),
273+
times = tlist, # Temporary fix
274+
Hdims = H_eff_evo.dims, # Temporary fix
273275
params...,
274276
)
275277

@@ -291,7 +293,7 @@ function mcsolveProblem(
291293
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
292294
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
293295

294-
return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs2...)
296+
return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs2...).prob # Temporary fix
295297
end
296298

297299
function mcsolveProblem(
@@ -315,7 +317,7 @@ function mcsolveProblem(
315317
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
316318
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
317319

318-
return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs2...)
320+
return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs2...).prob # Temporary fix
319321
end
320322

321323
@doc raw"""

src/time_evolution/sesolve.jl

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,55 @@
11
export sesolveProblem, sesolve
22

3+
# When e_ops is Nothing
34
function _save_func_sesolve(integrator)
4-
internal_params = integrator.p
5-
progr = internal_params.progr
6-
7-
if !internal_params.is_empty_e_ops
8-
e_ops = internal_params.e_ops
9-
expvals = internal_params.expvals
5+
next!(integrator.p.progr)
6+
return u_modified!(integrator, false)
7+
end
108

9+
# When e_ops is a list of operators
10+
function _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
11+
expvals = integrator.p.expvals
12+
progr = integrator.p.progr
13+
if !is_empty_e_ops
1114
ψ = integrator.u
12-
_expect = op -> dot(ψ, op, ψ)
15+
_expect = op -> dot(ψ, get_data(op), ψ)
1316
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
1417
end
15-
next!(progr)
16-
return u_modified!(integrator, false)
18+
return _save_func_sesolve(integrator)
19+
end
20+
21+
# Generate the callback depending on the e_ops type
22+
function _generate_sesolve_callback(e_ops::Nothing, tlist)
23+
f = integrator -> _save_func_sesolve(integrator)
24+
return PresetTimeCallback(tlist, f, save_positions = (false, false))
25+
end
26+
27+
function _generate_sesolve_callback(e_ops, tlist)
28+
is_empty_e_ops = isempty(e_ops)
29+
f = integrator -> _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
30+
return PresetTimeCallback(tlist, f, save_positions = (false, false))
1731
end
1832

19-
function _generate_sesolve_kwargs_with_callback(tlist, kwargs)
20-
cb1 = PresetTimeCallback(tlist, _save_func_sesolve, save_positions = (false, false))
33+
function _merge_sesolve_kwargs_with_callback(kwargs, cb)
2134
kwargs2 =
22-
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) :
23-
merge(kwargs, (callback = cb1,))
35+
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb),)) :
36+
merge(kwargs, (callback = cb,))
2437

2538
return kwargs2
2639
end
2740

41+
# Multiple dispatch depending on the progress_bar and e_ops types
2842
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, tlist, kwargs)
29-
return _generate_sesolve_kwargs_with_callback(tlist, kwargs)
43+
cb = _generate_sesolve_callback(e_ops, tlist)
44+
return _merge_sesolve_kwargs_with_callback(kwargs, cb)
3045
end
3146

3247
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, tlist, kwargs)
3348
if e_ops isa Nothing
3449
return kwargs
3550
end
36-
return _generate_sesolve_kwargs_with_callback(tlist, kwargs)
51+
cb = _generate_sesolve_callback(e_ops, tlist)
52+
return _merge_sesolve_kwargs_with_callback(kwargs, cb)
3753
end
3854

3955
_sesolve_make_U_QobjEvo(H::QuantumObjectEvolution{<:MatrixOperator}) =
@@ -103,31 +119,23 @@ function sesolveProblem(
103119

104120
if e_ops isa Nothing
105121
expvals = Array{ComplexF64}(undef, 0, length(tlist))
106-
e_ops_data = ()
107122
is_empty_e_ops = true
108123
else
109124
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
110-
e_ops_data = get_data.(e_ops)
111125
is_empty_e_ops = isempty(e_ops)
112126
end
113127

114-
p = (
115-
e_ops = e_ops_data,
116-
expvals = expvals,
117-
progr = progr,
118-
times = tlist,
119-
Hdims = H_evo.dims,
120-
is_empty_e_ops = is_empty_e_ops,
121-
params...,
122-
)
128+
p = QuantumTimeEvoParameters(expvals, progr, params)
123129

124130
saveat = is_empty_e_ops ? tlist : [tlist[end]]
125131
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
126132
kwargs2 = merge(default_values, kwargs)
127133
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2)
128134

129135
tspan = (tlist[1], tlist[end])
130-
return ODEProblem{true,FullSpecialize}(U, ψ0, tspan, p; kwargs3...)
136+
prob = ODEProblem{true,FullSpecialize}(U, ψ0, tspan, p; kwargs3...)
137+
138+
return QuantumTimeEvoProblem(prob, tlist, H_evo.dims)
131139
end
132140

133141
@doc raw"""
@@ -186,13 +194,13 @@ function sesolve(
186194
return sesolve(prob, alg)
187195
end
188196

189-
function sesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
190-
sol = solve(prob, alg)
197+
function sesolve(prob::QuantumTimeEvoProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
198+
sol = solve(prob.prob, alg)
191199

192-
ψt = map-> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)
200+
ψt = map-> QuantumObject(ϕ, type = Ket, dims = prob.dims), sol.u)
193201

194202
return TimeEvolutionSol(
195-
sol.prob.p.times,
203+
prob.times,
196204
ψt,
197205
sol.prob.p.expvals,
198206
sol.retcode,

src/time_evolution/ssesolve.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,9 @@ function ssesolveProblem(
184184

185185
if e_ops isa Nothing
186186
expvals = Array{ComplexF64}(undef, 0, length(tlist))
187-
e_ops_data = ()
188187
is_empty_e_ops = true
189188
else
190189
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
191-
e_ops_data = get_data.(e_ops)
192190
is_empty_e_ops = isempty(e_ops)
193191
end
194192

@@ -205,16 +203,7 @@ function ssesolveProblem(
205203
D_l = map(op -> op + _ScalarOperator_e(op, -) * IdentityOperator(prod(dims)), sc_ops_evo_data)
206204
D = DiffusionOperator(D_l)
207205

208-
p = (
209-
e_ops = e_ops_data,
210-
expvals = expvals,
211-
progr = progr,
212-
times = tlist,
213-
Hdims = dims,
214-
is_empty_e_ops = is_empty_e_ops,
215-
n_sc_ops = length(sc_ops),
216-
params...,
217-
)
206+
p = (expvals = expvals, progr = progr, times = tlist, Hdims = dims, n_sc_ops = length(sc_ops), params...)
218207

219208
saveat = is_empty_e_ops ? tlist : [tlist[end]]
220209
default_values = (DEFAULT_SDE_SOLVER_OPTIONS..., saveat = saveat)

src/time_evolution/time_evolution.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,60 @@ export liouvillian_floquet, liouvillian_generalized
55
const DEFAULT_ODE_SOLVER_OPTIONS = (abstol = 1e-8, reltol = 1e-6, save_everystep = false, save_end = true)
66
const DEFAULT_SDE_SOLVER_OPTIONS = (abstol = 1e-2, reltol = 1e-2, save_everystep = false, save_end = true)
77

8+
# This function should be implemented after Julia v1.12
9+
Base.@constprop :aggressive function _delete_field(a::NamedTuple{an}, field::Symbol) where {an}
10+
names = Base.diff_names(an, (field,))
11+
return NamedTuple{names}(a)
12+
end
13+
14+
struct QuantumTimeEvoParameters{TE<:AbstractMatrix,PT<:ProgressBar,ParT}
15+
expvals::TE
16+
progr::PT
17+
params::ParT
18+
19+
function QuantumTimeEvoParameters(expvals, progr, params)
20+
_expvals = expvals
21+
_progr = progr
22+
_params = params
23+
24+
# We replace the fields if they are aleady in the `params` struct
25+
# Then, we remove them from the `params` struct
26+
if :expvals fieldnames(typeof(_params))
27+
_expvals = _params.expvals
28+
_params = _delete_field(_params, :expvals)
29+
end
30+
if :progr fieldnames(typeof(_params))
31+
_progr = _params.progr
32+
_params = _delete_field(_params, :progr)
33+
end
34+
35+
return new{typeof(_expvals),typeof(_progr),typeof(_params)}(_expvals, _progr, _params)
36+
end
37+
end
38+
39+
#=
40+
By defining a custom `getproperty` method for the `QuantumTimeEvoParameters` struct, we can access the fields of `params` directly.
41+
=#
42+
function Base.getproperty(obj::QuantumTimeEvoParameters, field::Symbol)
43+
if field fieldnames(typeof(obj))
44+
getfield(obj, field)
45+
elseif field fieldnames(typeof(obj.params))
46+
getfield(obj.params, field)
47+
else
48+
throw(KeyError("Field $field not found in QuantumTimeEvoParameters or params."))
49+
end
50+
end
51+
52+
function Base.merge(a::QuantumTimeEvoParameters, b::NamedTuple)
53+
return QuantumTimeEvoParameters(a.expvals, a.progr, merge(a.params, b))
54+
end
55+
56+
struct QuantumTimeEvoProblem{PT,TT<:AbstractVector,DT<:AbstractVector}
57+
prob::PT
58+
times::TT
59+
dims::DT
60+
end
61+
862
@doc raw"""
963
struct TimeEvolutionSol
1064

0 commit comments

Comments
 (0)