Skip to content

Commit 98c26ea

Browse files
Working version of ssesolve
1 parent 841d02a commit 98c26ea

File tree

4 files changed

+118
-99
lines changed

4 files changed

+118
-99
lines changed

src/QuantumToolbox.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,20 @@ import SciMLBase:
3434
DiscreteCallback
3535
import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA1
3636
import SciMLOperators:
37-
AbstractSciMLOperator, MatrixOperator, ScalarOperator, cache_operator, update_coefficients!, concretize, isconstant
37+
AbstractSciMLOperator,
38+
MatrixOperator,
39+
ScalarOperator,
40+
IdentityOperator,
41+
cache_operator,
42+
update_coefficients!,
43+
concretize,
44+
isconstant
3845
import LinearSolve: LinearProblem, SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES
3946
import DiffEqBase: get_tstops
4047
import DiffEqCallbacks: PeriodicCallback, PresetTimeCallback, TerminateSteadyState
4148
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm
4249
import OrdinaryDiffEqTsit5: Tsit5
43-
import DiffEqNoiseProcess: RealWienerProcess
50+
import DiffEqNoiseProcess: RealWienerProcess!
4451

4552
# other dependencies (in alphabetical order)
4653
import ArrayInterface: allowed_getindex, allowed_setindex!

src/steadystate.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,16 @@ function steadystate(
236236
}
237237
ftype = _FType(ψ0)
238238
cb = TerminateSteadyState(abstol, reltol, _steadystate_ode_condition)
239-
sol = mesolve(H, ψ0, [ftype(0), ftype(tmax)], c_ops, progress_bar = Val(false), save_everystep=false, saveat=ftype[], callback = cb)
239+
sol = mesolve(
240+
H,
241+
ψ0,
242+
[ftype(0), ftype(tmax)],
243+
c_ops,
244+
progress_bar = Val(false),
245+
save_everystep = false,
246+
saveat = ftype[],
247+
callback = cb,
248+
)
240249

241250
ρss = sol.states[end]
242251
return ρss

src/time_evolution/ssesolve.jl

Lines changed: 95 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,48 @@
11
export ssesolveProblem, ssesolveEnsembleProblem, ssesolve
22

3-
#TODO: Check if works in GPU
4-
function _ssesolve_update_coefficients!(ψ, coefficients, sc_ops)
5-
_get_en = op -> real(dot(ψ, op, ψ)) #this is en/2: <Sn + Sn'>/2 = Re<Sn>
6-
@. coefficients[2:end-1] = _get_en(sc_ops) #coefficients of the OperatorSum: Σ Sn * en/2
7-
coefficients[end] = -sum(x -> x^2, coefficients[2:end-1]) / 2 #this last coefficient is -Σen^2/8
8-
return nothing
3+
#=
4+
struct DiffusionOperator
5+
6+
A struct to represent the diffusion operator. This is used to perform the diffusion process on N different Wiener processes.
7+
=#
8+
struct DiffusionOperator{T,OT<:Tuple{Vararg{AbstractSciMLOperator}}} <: AbstractSciMLOperator{T}
9+
ops::OT
10+
function DiffusionOperator(ops::OT) where {OT}
11+
T = mapreduce(eltype, promote_type, ops)
12+
return new{T,OT}(ops)
13+
end
914
end
1015

11-
function ssesolve_drift!(du, u, p, t)
12-
_ssesolve_update_coefficients!(u, p.K.coefficients, p.sc_ops)
13-
14-
mul!(du, p.K, u)
16+
@generated function update_coefficients!(L::DiffusionOperator, u, p, t)
17+
ops_types = L.parameters[2].parameters
18+
N = length(ops_types)
19+
quote
20+
Base.@nexprs $N i -> begin
21+
update_coefficients!(L.ops[i], u, p, t)
22+
end
1523

16-
return nothing
24+
nothing
25+
end
1726
end
1827

19-
function ssesolve_diffusion!(du, u, p, t)
20-
@inbounds en = @view(p.K.coefficients[2:end-1])
21-
22-
# du:(H,W). du_reshaped:(H*W,).
23-
# H:Hilbert space dimension, W: number of sc_ops
24-
du_reshaped = reshape(du, :)
25-
mul!(du_reshaped, p.D, u) #du[:,i] = D[i] * u
26-
27-
du .-= u .* reshape(en, 1, :) #du[:,i] -= en[i] * u
28+
@generated function LinearAlgebra.mul!(v::AbstractVecOrMat, L::DiffusionOperator, u::AbstractVecOrMat)
29+
ops_types = L.parameters[2].parameters
30+
N = length(ops_types)
31+
quote
32+
M = length(u)
33+
S = size(v)
34+
(S[1] == M && S[2] == $N) || throw(DimensionMismatch("The size of the output vector is incorrect."))
35+
Base.@nexprs $N i -> begin
36+
mul!(@view(v[:, i]), L.ops[i], u)
37+
end
38+
v
39+
end
40+
end
2841

29-
return nothing
42+
# TODO: Implement the three-argument dot function for SciMLOperators.jl
43+
# Currently, we are assuming a time-independent MatrixOperator
44+
function _ssesolve_update_coeff(u, p, t, op)
45+
return real(dot(u, op.A, u)) #this is en/2: <Sn + Sn'>/2 = Re<Sn>
3046
end
3147

3248
function _ssesolve_prob_func(prob, i, repeat)
@@ -37,27 +53,15 @@ function _ssesolve_prob_func(prob, i, repeat)
3753
traj_rng = typeof(global_rng)()
3854
seed!(traj_rng, seed)
3955

40-
noise = RealWienerProcess(
56+
noise = RealWienerProcess!(
4157
prob.tspan[1],
42-
zeros(length(internal_params.sc_ops)),
43-
zeros(length(internal_params.sc_ops)),
58+
zeros(internal_params.n_sc_ops),
59+
zeros(internal_params.n_sc_ops),
4460
save_everystep = false,
4561
rng = traj_rng,
4662
)
4763

48-
# noise_rate_prototype = similar(prob.u0, length(prob.u0), length(internal_params.sc_ops))
49-
50-
prm = merge(
51-
internal_params,
52-
(
53-
K = deepcopy(internal_params.K),
54-
D = deepcopy(internal_params.D),
55-
expvals = similar(internal_params.expvals),
56-
progr = ProgressBar(size(internal_params.expvals, 2), enable = false),
57-
),
58-
)
59-
60-
return remake(prob, p = prm, noise = noise, seed = seed)
64+
return remake(prob, noise = noise, seed = seed)
6165
end
6266

6367
# Standard output function
@@ -88,6 +92,11 @@ function _ssesolve_generate_statistics!(sol, i, states, expvals_all)
8892
return nothing
8993
end
9094

95+
_ScalarOperator_e(op, f = +) = ScalarOperator(one(eltype(op)), (a, u, p, t) -> f(_ssesolve_update_coeff(u, p, t, op)))
96+
97+
_ScalarOperator_e2_2(op, f = +) =
98+
ScalarOperator(one(eltype(op)), (a, u, p, t) -> f(_ssesolve_update_coeff(u, p, t, op)^2 / 2))
99+
91100
@doc raw"""
92101
ssesolveProblem(H::QuantumObject,
93102
ψ0::QuantumObject,
@@ -147,80 +156,77 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
147156
- `prob`: The `SDEProblem` for the Stochastic Schrödinger time evolution of the system.
148157
"""
149158
function ssesolveProblem(
150-
H::QuantumObject{MT1,OperatorQuantumObject},
151-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
159+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
160+
ψ0::QuantumObject{DT2,KetQuantumObject},
152161
tlist::AbstractVector,
153162
sc_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
154163
alg::StochasticDiffEqAlgorithm = SRA1(),
155164
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
156165
params::NamedTuple = NamedTuple(),
157166
rng::AbstractRNG = default_rng(),
158167
kwargs...,
159-
) where {MT1<:AbstractMatrix,T2}
160-
check_dims(H, ψ0)
161-
168+
) where {DT1,DT2}
162169
haskey(kwargs, :save_idxs) &&
163170
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
164171

165172
sc_ops isa Nothing &&
166173
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
167174

168-
# !(H_t isa Nothing) && throw(ArgumentError("Time-dependent Hamiltonians are not currently supported in ssesolve."))
169-
170-
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for StochasticDiffEq.jl
175+
tlist = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for StochasticDiffEq.jl
171176

172-
ϕ0 = get_data(ψ0)
177+
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, sc_ops)
178+
isoper(H_eff_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
179+
check_dims(H_eff_evo, ψ0)
180+
dims = H_eff_evo.dims
173181

174-
H_eff = get_data(H - T2(0.5im) * mapreduce(op -> op' * op, +, sc_ops))
175-
sc_ops2 = get_data.(sc_ops)
182+
ψ0 = get_data(ψ0)
176183

177-
coefficients = [1.0, fill(0.0, length(sc_ops) + 1)...]
178-
operators = [-1im * H_eff, sc_ops2..., MT1(I(prod(H.dims)))]
179-
K = OperatorSum(coefficients, operators)
180-
_ssesolve_update_coefficients!(ϕ0, K.coefficients, sc_ops2)
181-
182-
D = reduce(vcat, sc_ops2)
184+
progr = ProgressBar(length(tlist), enable = false)
183185

184186
if e_ops isa Nothing
185-
expvals = Array{ComplexF64}(undef, 0, length(t_l))
186-
e_ops2 = MT1[]
187+
expvals = Array{ComplexF64}(undef, 0, length(tlist))
188+
e_ops_data = ()
187189
is_empty_e_ops = true
188190
else
189-
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
190-
e_ops2 = get_data.(e_ops)
191+
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
192+
e_ops_data = get_data.(e_ops)
191193
is_empty_e_ops = isempty(e_ops)
192194
end
193195

196+
sc_ops_evo_data = Tuple(map(get_data QobjEvo, sc_ops))
197+
198+
# Here the coefficients depend on the state, so this is a non-linear operator, which should be implemented with FunctionOperator instead. However, the nonlinearity is only on the coefficients, and it should be safe.
199+
K_l = sum(
200+
op -> _ScalarOperator_e(op, +) * op + _ScalarOperator_e2_2(op, -) * IdentityOperator(prod(dims)),
201+
sc_ops_evo_data,
202+
)
203+
204+
K = -1im * get_data(H_eff_evo) + K_l
205+
206+
D_l = map(op -> op + _ScalarOperator_e(op, -) * IdentityOperator(prod(dims)), sc_ops_evo_data)
207+
D = DiffusionOperator(D_l)
208+
194209
p = (
195-
K = K,
196-
D = D,
197-
e_ops = e_ops2,
198-
sc_ops = sc_ops2,
210+
e_ops = e_ops_data,
199211
expvals = expvals,
200-
Hdims = H.dims,
201-
times = t_l,
212+
progr = progr,
213+
times = tlist,
214+
Hdims = dims,
202215
is_empty_e_ops = is_empty_e_ops,
216+
n_sc_ops = length(sc_ops),
203217
params...,
204218
)
205219

206-
saveat = is_empty_e_ops ? t_l : [t_l[end]]
220+
saveat = is_empty_e_ops ? tlist : [tlist[end]]
207221
default_values = (DEFAULT_SDE_SOLVER_OPTIONS..., saveat = saveat)
208222
kwargs2 = merge(default_values, kwargs)
209-
kwargs3 = _generate_sesolve_kwargs(e_ops, Val(false), t_l, kwargs2)
210-
211-
tspan = (t_l[1], t_l[end])
212-
noise = RealWienerProcess(t_l[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false, rng = rng)
213-
noise_rate_prototype = similar(ϕ0, length(ϕ0), length(sc_ops))
214-
return SDEProblem{true}(
215-
ssesolve_drift!,
216-
ssesolve_diffusion!,
217-
ϕ0,
218-
tspan,
219-
p;
220-
noise_rate_prototype = noise_rate_prototype,
221-
noise = noise,
222-
kwargs3...,
223-
)
223+
kwargs3 = _generate_sesolve_kwargs(e_ops, Val(false), tlist, kwargs2)
224+
225+
tspan = (tlist[1], tlist[end])
226+
noise =
227+
RealWienerProcess!(tlist[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false, rng = rng)
228+
noise_rate_prototype = similar(ψ0, length(ψ0), length(sc_ops))
229+
return SDEProblem{true}(K, D, ψ0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise, kwargs3...)
224230
end
225231

226232
@doc raw"""
@@ -290,8 +296,8 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
290296
- `prob::EnsembleProblem with SDEProblem`: The Ensemble SDEProblem for the Stochastic Shrödinger time evolution.
291297
"""
292298
function ssesolveEnsembleProblem(
293-
H::QuantumObject{MT1,OperatorQuantumObject},
294-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
299+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
300+
ψ0::QuantumObject{DT2,KetQuantumObject},
295301
tlist::AbstractVector,
296302
sc_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
297303
alg::StochasticDiffEqAlgorithm = SRA1(),
@@ -304,7 +310,7 @@ function ssesolveEnsembleProblem(
304310
output_func::Function = _ssesolve_dispatch_output_func(ensemble_method),
305311
progress_bar::Union{Val,Bool} = Val(true),
306312
kwargs...,
307-
) where {MT1<:AbstractMatrix,T2}
313+
) where {DT1,DT2}
308314
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
309315
if ensemble_method isa EnsembleDistributed
310316
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@@ -331,7 +337,9 @@ function ssesolveEnsembleProblem(
331337
kwargs...,
332338
)
333339

334-
ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = false)
340+
# safetycopy is set to true because the K and D functions cannot be currently deepcopied.
341+
# the memory overhead shouldn't be too large, compared to the safetycopy=false case.
342+
ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = true)
335343

336344
return ensemble_prob
337345
catch e
@@ -413,8 +421,8 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
413421
- `sol::TimeEvolutionSSESol`: The solution of the time evolution. See also [`TimeEvolutionSSESol`](@ref)
414422
"""
415423
function ssesolve(
416-
H::QuantumObject{MT1,OperatorQuantumObject},
417-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
424+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
425+
ψ0::QuantumObject{DT2,KetQuantumObject},
418426
tlist::AbstractVector,
419427
sc_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
420428
alg::StochasticDiffEqAlgorithm = SRA1(),
@@ -427,13 +435,7 @@ function ssesolve(
427435
output_func::Function = _ssesolve_dispatch_output_func(ensemble_method),
428436
progress_bar::Union{Val,Bool} = Val(true),
429437
kwargs...,
430-
) where {MT1<:AbstractMatrix,T2}
431-
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
432-
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
433-
@async while take!(progr_channel)
434-
next!(progr)
435-
end
436-
438+
) where {DT1,DT2}
437439
ens_prob = ssesolveEnsembleProblem(
438440
H,
439441
ψ0,

test/core-test/time_evolution.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,18 @@
137137
end
138138

139139
@testset "Type Inference ssesolve" begin
140+
c_ops_tuple = Tuple(c_ops) # To avoid type instability, we must have a Tuple instead of a Vector
140141
@inferred ssesolveEnsembleProblem(
141142
H,
142143
psi0,
143144
t_l,
144-
c_ops,
145+
c_ops_tuple,
145146
ntraj = 500,
146147
e_ops = e_ops,
147148
progress_bar = Val(false),
148149
)
149-
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
150-
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, progress_bar = Val(true))
150+
@inferred ssesolve(H, psi0, t_l, c_ops_tuple, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
151+
@inferred ssesolve(H, psi0, t_l, c_ops_tuple, ntraj = 500, progress_bar = Val(true))
151152
end
152153

153154
@testset "mcsolve and ssesolve reproducibility" begin

0 commit comments

Comments
 (0)