Skip to content

Commit cdc63bd

Browse files
[no ci] Add state normalization during evolution in ssesolve
1 parent b63c19b commit cdc63bd

File tree

5 files changed

+73
-49
lines changed

5 files changed

+73
-49
lines changed

src/QuantumToolbox.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ include("time_evolution/time_evolution.jl")
9999
include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
100100
include("time_evolution/callback_helpers/mesolve_callback_helpers.jl")
101101
include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl")
102+
include("time_evolution/callback_helpers/ssesolve_callback_helpers.jl")
102103
include("time_evolution/callback_helpers/callback_helpers.jl")
103104
include("time_evolution/mesolve.jl")
104105
include("time_evolution/lr_mesolve.jl")

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,26 @@ end
3232

3333
_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
3434
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
35+
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
3536

3637
_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))
3738

39+
#=
40+
This function add the normalization callback to the kwargs. It is needed to stabilize the integration when using the ssesolve method.
41+
=#
42+
function _ssesolve_add_normalize_cb(kwargs)
43+
_condition = (u, t, integrator) -> true
44+
_affect! = (integrator) -> normalize!(integrator.u)
45+
cb = DiscreteCallback(_condition, _affect!; save_positions = (false, false))
46+
# return merge(kwargs, (callback = CallbackSet(kwargs[:callback], cb),))
47+
48+
cb_set = haskey(kwargs, :callback) ? CallbackSet(kwargs[:callback], cb) : cb
49+
50+
kwargs2 = merge(kwargs, (callback = cb_set,))
51+
52+
return kwargs2
53+
end
54+
3855
##
3956

4057
# When e_ops is Nothing. Common for both mesolve and sesolve
@@ -80,10 +97,10 @@ function _se_me_sse_get_save_callback(cb::CallbackSet)
8097
return nothing
8198
end
8299
end
83-
_se_me_sse_get_save_callback(cb::DiscreteCallback) =
84-
if (cb.affect! isa SaveFuncSESolve) || (cb.affect! isa SaveFuncMESolve)
100+
function _se_me_sse_get_save_callback(cb::DiscreteCallback)
101+
if typeof(cb.affect!) <: Union{SaveFuncSESolve,SaveFuncMESolve,SaveFuncSSESolve}
85102
return cb
86-
else
87-
return nothing
88103
end
104+
return nothing
105+
end
89106
_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#=
2+
Helper functions for the ssesolve callbacks. Equal to the sesolve case, but with an additional normalization before saving the expectation values.
3+
=#
4+
5+
struct SaveFuncSSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
6+
e_ops::TE
7+
progr::PT
8+
iter::IT
9+
expvals::TEXPV
10+
end
11+
12+
(f::SaveFuncSSESolve)(integrator) = _save_func_ssesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
13+
(f::SaveFuncSSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
14+
15+
##
16+
17+
# When e_ops is a list of operators
18+
function _save_func_ssesolve(integrator, e_ops, progr, iter, expvals)
19+
ψ = normalize!(integrator.u)
20+
_expect = op -> dot(ψ, op, ψ)
21+
@. expvals[:, iter[]] = _expect(e_ops)
22+
iter[] += 1
23+
24+
return _save_func(integrator, progr)
25+
end

src/time_evolution/ssesolve.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ end
3535
Base.@nexprs $N i -> begin
3636
mul!(@view(v[:, i]), L.ops[i], u)
3737
end
38-
v
38+
return v
3939
end
4040
end
4141

4242
# TODO: Implement the three-argument dot function for SciMLOperators.jl
4343
# Currently, we are assuming a time-independent MatrixOperator
4444
function _ssesolve_update_coeff(u, p, t, op)
45+
normalize!(u)
4546
return real(dot(u, op.A, u)) #this is en/2: <Sn + Sn'>/2 = Re<Sn>
4647
end
4748

@@ -104,23 +105,23 @@ _ScalarOperator_e2_2(op, f = +) =
104105
Generate the SDEProblem for the Stochastic Schrödinger time evolution of a quantum system. This is defined by the following stochastic differential equation:
105106
106107
```math
107-
d|\psi(t)\rangle = -i K |\psi(t)\rangle dt + \sum_n M_n |\psi(t)\rangle dW_n(t)
108+
d|\psi(t)\rangle = -i \hat{K} |\psi(t)\rangle dt + \sum_n \hat{M}_n |\psi(t)\rangle dW_n(t)
108109
```
109110
110111
where
111112
112113
```math
113-
K = \hat{H} + i \sum_n \left(\frac{e_j} C_n - \frac{1}{2} \sum_{j} C_n^\dagger C_n - \frac{e_j^2}{8}\right),
114+
\hat{K} = \hat{H} + i \sum_n \left(\frac{e_n}{2} \hat{C}_n - \frac{1}{2} \hat{C}_n^\dagger \hat{C}_n - \frac{e_n^2}{8}\right),
114115
```
115116
```math
116-
M_n = C_n - \frac{e_n}{2},
117+
\hat{M}_n = \hat{C}_n - \frac{e_n}{2},
117118
```
118119
and
119120
```math
120-
e_n = \langle C_n + C_n^\dagger \rangle.
121+
e_n = \langle \hat{C}_n + \hat{C}_n^\dagger \rangle.
121122
```
122123
123-
Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener increment associated to `C_n`.
124+
Above, `\hat{C}_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener increment associated to `\hat{C}_n`.
124125
125126
# Arguments
126127
@@ -193,13 +194,14 @@ function ssesolveProblem(
193194
saveat = is_empty_e_ops ? tlist : [tlist[end]]
194195
default_values = (DEFAULT_SDE_SOLVER_OPTIONS..., saveat = saveat)
195196
kwargs2 = merge(default_values, kwargs)
196-
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve)
197+
kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSSESolve)
198+
kwargs4 = _ssesolve_add_normalize_cb(kwargs3)
197199

198200
tspan = (tlist[1], tlist[end])
199201
noise =
200202
RealWienerProcess!(tlist[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false, rng = rng)
201203
noise_rate_prototype = similar(ψ0, length(ψ0), length(sc_ops))
202-
return SDEProblem{true}(K, D, ψ0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise, kwargs3...)
204+
return SDEProblem{true}(K, D, ψ0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise, kwargs4...)
203205
end
204206

205207
@doc raw"""
@@ -222,23 +224,23 @@ end
222224
Generate the SDE EnsembleProblem for the Stochastic Schrödinger time evolution of a quantum system. This is defined by the following stochastic differential equation:
223225
224226
```math
225-
d|\psi(t)\rangle = -i K |\psi(t)\rangle dt + \sum_n M_n |\psi(t)\rangle dW_n(t)
227+
d|\psi(t)\rangle = -i \hat{K} |\psi(t)\rangle dt + \sum_n \hat{M}_n |\psi(t)\rangle dW_n(t)
226228
```
227229
228230
where
229231
230232
```math
231-
K = \hat{H} + i \sum_n \left(\frac{e_j} C_n - \frac{1}{2} \sum_{j} C_n^\dagger C_n - \frac{e_j^2}{8}\right),
233+
\hat{K} = \hat{H} + i \sum_n \left(\frac{e_n}{2} \hat{C}_n - \frac{1}{2} \hat{C}_n^\dagger \hat{C}_n - \frac{e_n^2}{8}\right),
232234
```
233235
```math
234-
M_n = C_n - \frac{e_n}{2},
236+
\hat{M}_n = \hat{C}_n - \frac{e_n}{2},
235237
```
236238
and
237239
```math
238-
e_n = \langle C_n + C_n^\dagger \rangle.
240+
e_n = \langle \hat{C}_n + \hat{C}_n^\dagger \rangle.
239241
```
240242
241-
Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener increment associated to `C_n`.
243+
Above, `\hat{C}_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener increment associated to `\hat{C}_n`.
242244
243245
# Arguments
244246
@@ -345,23 +347,23 @@ Stochastic Schrödinger equation evolution of a quantum system given the system
345347
The stochastic evolution of the state ``|\psi(t)\rangle`` is defined by:
346348
347349
```math
348-
d|\psi(t)\rangle = -i K |\psi(t)\rangle dt + \sum_n M_n |\psi(t)\rangle dW_n(t)
350+
d|\psi(t)\rangle = -i \hat{K} |\psi(t)\rangle dt + \sum_n \hat{M}_n |\psi(t)\rangle dW_n(t)
349351
```
350352
351353
where
352354
353355
```math
354-
K = \hat{H} + i \sum_n \left(\frac{e_j} C_n - \frac{1}{2} \sum_{j} C_n^\dagger C_n - \frac{e_j^2}{8}\right),
356+
\hat{K} = \hat{H} + i \sum_n \left(\frac{e_n}{2} \hat{C}_n - \frac{1}{2} \hat{C}_n^\dagger \hat{C}_n - \frac{e_n^2}{8}\right),
355357
```
356358
```math
357-
M_n = C_n - \frac{e_n}{2},
359+
\hat{M}_n = \hat{C}_n - \frac{e_n}{2},
358360
```
359361
and
360362
```math
361-
e_n = \langle C_n + C_n^\dagger \rangle.
363+
e_n = \langle \hat{C}_n + \hat{C}_n^\dagger \rangle.
362364
```
363365
364-
Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener increment associated to `C_n`.
366+
Above, `\hat{C}_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener increment associated to `\hat{C}_n`.
365367
366368
367369
# Arguments

test/core-test/time_evolution.jl

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
e_ops = [a' * a, σz]
1919
c_ops = [sqrt* (1 + nth)) * a, sqrt* nth) * a', sqrt* (1 + nth)) * σm, sqrt* nth) * σm']
2020

21+
ψ0_int = Qobj(round.(Int, ψ0.data), dims = ψ0.dims) # Used for testing the type inference
22+
2123
@testset "sesolve" begin
2224
tlist = range(0, 20 * 2π / g, 1000)
2325

@@ -83,7 +85,7 @@
8385
@testset "Type Inference sesolve" begin
8486
@inferred sesolveProblem(H, ψ0, tlist, progress_bar = Val(false))
8587
@inferred sesolveProblem(H, ψ0, [0, 10], progress_bar = Val(false))
86-
@inferred sesolveProblem(H, Qobj(zeros(Int64, N * 2); dims = (N, 2)), tlist, progress_bar = Val(false))
88+
@inferred sesolveProblem(H, ψ0_int, tlist, progress_bar = Val(false))
8789
@inferred sesolve(H, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false))
8890
@inferred sesolve(H, ψ0, tlist, progress_bar = Val(false))
8991
@inferred sesolve(H, ψ0, tlist, e_ops = e_ops, saveat = tlist, progress_bar = Val(false))
@@ -367,14 +369,7 @@
367369
ad_t = QobjEvo(a', coef)
368370
@inferred mesolveProblem(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false))
369371
@inferred mesolveProblem(H, ψ0, [0, 10], c_ops, e_ops = e_ops, progress_bar = Val(false))
370-
@inferred mesolveProblem(
371-
H,
372-
tensor(Qobj(zeros(Int64, N)), Qobj([0, 1])),
373-
tlist,
374-
c_ops,
375-
e_ops = e_ops,
376-
progress_bar = Val(false),
377-
)
372+
@inferred mesolveProblem(H, ψ0_int, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false))
378373
@inferred mesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false))
379374
@inferred mesolve(H, ψ0, tlist, c_ops, progress_bar = Val(false))
380375
@inferred mesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, saveat = tlist, progress_bar = Val(false))
@@ -398,15 +393,7 @@
398393
@inferred mcsolve(H, ψ0, tlist, c_ops, ntraj = 5, e_ops = e_ops, progress_bar = Val(false), rng = rng)
399394
@inferred mcsolve(H, ψ0, tlist, c_ops, ntraj = 5, progress_bar = Val(true), rng = rng)
400395
@inferred mcsolve(H, ψ0, [0, 10], c_ops, ntraj = 5, progress_bar = Val(false), rng = rng)
401-
@inferred mcsolve(
402-
H,
403-
tensor(Qobj(zeros(Int64, N)), Qobj([0, 1])),
404-
tlist,
405-
c_ops,
406-
ntraj = 5,
407-
progress_bar = Val(false),
408-
rng = rng,
409-
)
396+
@inferred mcsolve(H, ψ0_int, tlist, c_ops, ntraj = 5, progress_bar = Val(false), rng = rng)
410397
@inferred mcsolve(
411398
H,
412399
ψ0,
@@ -454,15 +441,7 @@
454441
)
455442
@inferred ssesolve(H, ψ0, tlist, c_ops_tuple, ntraj = 5, progress_bar = Val(true), rng = rng)
456443
@inferred ssesolve(H, ψ0, [0, 10], c_ops_tuple, ntraj = 5, progress_bar = Val(false), rng = rng)
457-
@inferred ssesolve(
458-
H,
459-
tensor(Qobj(zeros(Int64, N)), Qobj([0, 1])),
460-
tlist,
461-
c_ops_tuple,
462-
ntraj = 5,
463-
progress_bar = Val(false),
464-
rng = rng,
465-
)
444+
@inferred ssesolve(H, ψ0_int, tlist, c_ops_tuple, ntraj = 5, progress_bar = Val(false), rng = rng)
466445
@inferred ssesolve(
467446
H,
468447
ψ0,

0 commit comments

Comments
 (0)