Skip to content

Commit 641e213

Browse files
Add mesolve_map case
1 parent d0286e4 commit 641e213

File tree

3 files changed

+244
-16
lines changed

3 files changed

+244
-16
lines changed

src/time_evolution/mesolve.jl

Lines changed: 156 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1-
export mesolveProblem, mesolve
1+
export mesolveProblem, mesolve, mesolve_map
22

33
_mesolve_make_L_QobjEvo(H::Union{QuantumObject,Nothing}, c_ops) = QobjEvo(liouvillian(H, c_ops); type = SuperOperator())
44
_mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvillian(QobjEvo(H), c_ops)
55
_mesolve_make_L_QobjEvo(H::Nothing, c_ops::Nothing) = throw(ArgumentError("Both H and
66
c_ops are Nothing. You are probably running the wrong function."))
77

8+
function _gen_mesolve_solution(sol, times, dimensions, isoperket::Val)
9+
if getVal(isoperket)
10+
ρt = map-> QuantumObject(ϕ, type = OperatorKet(), dims = dimensions), sol.u)
11+
else
12+
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator(), dims = dimensions), sol.u)
13+
end
14+
15+
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
16+
17+
return TimeEvolutionSol(
18+
times,
19+
sol.t,
20+
ρt,
21+
_get_expvals(sol, SaveFuncMESolve),
22+
sol.retcode,
23+
sol.alg,
24+
kwargs.abstol,
25+
kwargs.reltol,
26+
)
27+
end
28+
829
@doc raw"""
930
mesolveProblem(
1031
H::Union{AbstractQuantumObject,Tuple},
@@ -207,23 +228,143 @@ end
207228
function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...)
208229
sol = solve(prob.prob, alg; kwargs...)
209230

210-
# No type instabilities since `isoperket` is a Val, and so it is known at compile time
211-
if getVal(prob.kwargs.isoperket)
212-
ρt = map-> QuantumObject(ϕ, type = OperatorKet(), dims = prob.dimensions), sol.u)
213-
else
214-
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator(), dims = prob.dimensions), sol.u)
231+
return _gen_mesolve_solution(sol, prob.times, prob.dimensions, prob.kwargs.isoperket)
232+
end
233+
234+
@doc raw"""
235+
mesolve_map(
236+
H::Union{AbstractQuantumObject,Tuple},
237+
ψ0::Union{QuantumObject,AbstractVector{<:QuantumObject}},
238+
tlist::AbstractVector,
239+
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
240+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
241+
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
242+
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
243+
params::Tuple = (NullParameters(),),
244+
progress_bar::Union{Val,Bool} = Val(true),
245+
kwargs...,
246+
)
247+
248+
Solve the master equation for multiple initial states and parameter sets using ensemble simulation.
249+
250+
This function computes the time evolution for all combinations (Cartesian product) of initial states and parameter sets, solving the Lindblad master equation (see [`mesolve`](@ref)):
251+
252+
```math
253+
\frac{\partial \hat{\rho}(t)}{\partial t} = -i[\hat{H}, \hat{\rho}(t)] + \sum_n \mathcal{D}(\hat{C}_n) [\hat{\rho}(t)]
254+
```
255+
256+
where
257+
258+
```math
259+
\mathcal{D}(\hat{C}_n) [\hat{\rho}(t)] = \hat{C}_n \hat{\rho}(t) \hat{C}_n^\dagger - \frac{1}{2} \hat{C}_n^\dagger \hat{C}_n \hat{\rho}(t) - \frac{1}{2} \hat{\rho}(t) \hat{C}_n^\dagger \hat{C}_n
260+
```
261+
262+
for each combination in the ensemble.
263+
264+
# Arguments
265+
266+
- `H`: Hamiltonian of the system ``\hat{H}``. It can be either a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a `Tuple` of operator-function pairs.
267+
- `ψ0`: Initial state(s) of the system. Can be a single [`QuantumObject`](@ref) or a `Vector` of initial states. It can be either a [`Ket`](@ref), [`Operator`](@ref) or [`OperatorKet`](@ref).
268+
- `tlist`: List of time points at which to save either the state or the expectation values of the system.
269+
- `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`.
270+
- `alg`: The algorithm for the ODE solver. The default is `Tsit5()`.
271+
- `ensemblealg`: Ensemble algorithm to use for parallel computation. Default is `EnsembleThreads()`.
272+
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
273+
- `params`: A `Tuple` of parameter sets. Each element should be an `AbstractVector` representing the sweep range for that parameter. The function will solve for all combinations of initial states and parameter sets.
274+
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
275+
- `kwargs`: The keyword arguments for the ODEProblem.
276+
277+
# Notes
278+
279+
- The function returns an array of solutions with dimensions matching the Cartesian product of initial states and parameter sets.
280+
- If `ψ0` is a vector of `m` states and `params = (p1, p2, ...)` where `p1` has length `n1`, `p2` has length `n2`, etc., the output will be of size `(m, n1, n2, ...)`.
281+
- If `H` is an [`Operator`](@ref), `ψ0` is a [`Ket`](@ref) and `c_ops` is `Nothing`, the function will call [`sesolve_map`](@ref) instead.
282+
- See [`mesolve`](@ref) for more details.
283+
284+
# Returns
285+
286+
- An array of [`TimeEvolutionSol`](@ref) objects with dimensions `(length(ψ0), length(params[1]), length(params[2]), ...)`.
287+
"""
288+
function mesolve_map(
289+
H::Union{AbstractQuantumObject{HOpType},Tuple},
290+
ψ0::AbstractVector{<:QuantumObject{StateOpType}},
291+
tlist::AbstractVector,
292+
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
293+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
294+
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
295+
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
296+
params::Tuple = (NullParameters(),),
297+
progress_bar::Union{Val,Bool} = Val(true),
298+
kwargs...,
299+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
300+
(isoper(H) && all(isket, ψ0) && isnothing(c_ops)) && return sesolve_map(
301+
H,
302+
ψ0,
303+
tlist;
304+
alg = alg,
305+
ensemblealg = ensemblealg,
306+
e_ops = e_ops,
307+
params = params,
308+
progress_bar = progress_bar,
309+
kwargs...,
310+
)
311+
312+
# mapping initial states and parameters
313+
# Convert to appropriate format based on state type
314+
ψ0_iter = map(ψ0) do state
315+
T = _complex_float_type(eltype(state))
316+
if isoperket(state)
317+
to_dense(T, copy(state.data))
318+
else
319+
to_dense(T, mat2vec(ket2dm(state).data))
320+
end
215321
end
322+
iter = collect(Iterators.product(ψ0_iter, params...))
323+
ntraj = length(iter)
216324

217-
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
325+
# we disable the progress bar of the mesolveProblem because we use a global progress bar for all the trajectories
326+
prob = mesolveProblem(
327+
H,
328+
first(ψ0),
329+
tlist,
330+
c_ops;
331+
e_ops = e_ops,
332+
params = first(iter)[2:end],
333+
progress_bar = Val(false),
334+
kwargs...,
335+
)
218336

219-
return TimeEvolutionSol(
337+
# generate and solve ensemble problem
338+
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) # setup global progress bar
339+
ens_prob = TimeEvolutionProblem(
340+
EnsembleProblem(
341+
prob.prob,
342+
prob_func = (prob, i, repeat) -> remake(
343+
prob,
344+
f = deepcopy(prob.f.f),
345+
u0 = iter[i][1],
346+
p = iter[i][2:end],
347+
callback = haskey(prob.kwargs, :callback) ? deepcopy(prob.kwargs[:callback]) : nothing,
348+
),
349+
safetycopy = false,
350+
),
220351
prob.times,
221-
sol.t,
222-
ρt,
223-
_get_expvals(sol, SaveFuncMESolve),
224-
sol.retcode,
225-
sol.alg,
226-
kwargs.abstol,
227-
kwargs.reltol,
352+
prob.dimensions,
353+
(progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket),
354+
)
355+
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
356+
357+
# handle solution and make it become an Array of TimeEvolutionSol
358+
return reshape(
359+
map(i -> _gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket), eachindex(sol)),
360+
size(iter),
228361
)
229362
end
363+
mesolve_map(
364+
H::Union{AbstractQuantumObject{HOpType},Tuple},
365+
ψ0::QuantumObject{StateOpType},
366+
tlist::AbstractVector,
367+
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
368+
kwargs...,
369+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
370+
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)

src/time_evolution/sesolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ function sesolve_map(
270270
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
271271

272272
# handle solution and make it become an Array of TimeEvolutionSol
273-
return reshape(map(i -> _gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions), eachindex(sol)), size(iter))
273+
sol_vec = map(i -> _gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions), eachindex(sol))
274+
# sol_vec = _gen_sesolve_solution.(sol[:], Ref(prob.times), Ref(prob.dimensions))
275+
return reshape(sol_vec, size(iter))
274276
end
275277
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
276278
sesolve_map(H, [ψ0], tlist; kwargs...)

test/core-test/time_evolution.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ end
150150
@testitem "sesolve_map" setup=[TESetup] begin
151151

152152
# Get parameters from TESetup to simplify the code
153+
N = TESetup.N
153154
a = TESetup.a
154155
σz = TESetup.σz
155156
σm = TESetup.σm
@@ -296,6 +297,90 @@ end
296297
end
297298
end
298299

300+
@testitem "mesolve_map" setup=[TESetup] begin
301+
302+
# Get parameters from TESetup to simplify the code
303+
N = TESetup.N
304+
a = TESetup.a
305+
σz = TESetup.σz
306+
σm = TESetup.σm
307+
ψ0 = TESetup.ψ0
308+
c_ops = TESetup.c_ops
309+
e_ops = TESetup.e_ops
310+
γ = TESetup.γ
311+
nth = TESetup.nth
312+
313+
g = 0.01
314+
315+
ψ_0_e = tensor(fock(N, 0), basis(2, 0))
316+
ψ_1_g = tensor(fock(N, 1), basis(2, 1))
317+
318+
ψ0_list = [ψ_0_e, ψ_1_g]
319+
ωc_list = [1, 1.01, 1.02]
320+
ωq_list = [0.96, 0.97, 0.98, 0.99]
321+
322+
tlist = range(0, 10 / γ, 100)
323+
324+
ωc_fun(p, t) = p[1]
325+
ωq_fun(p, t) = p[2]
326+
H = QobjEvo(a' * a, ωc_fun) + QobjEvo(σz / 2, ωq_fun) + g * (a' * σm + a * σm')
327+
328+
# Test with single initial state
329+
sols1 = mesolve_map(H, ψ_0_e, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list))
330+
# Test with multiple initial states
331+
sols2 = mesolve_map(H, ψ0_list, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
332+
333+
# Test redirect to sesolve_map when c_ops is nothing
334+
sols3 = mesolve_map(H, ψ0_list, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
335+
336+
@test size(sols1) == (1, 3, 4)
337+
@test sols1 isa Array{<:TimeEvolutionSol}
338+
@test size(sols2) == (2, 3, 4)
339+
@test sols2 isa Array{<:TimeEvolutionSol}
340+
@test size(sols3) == (2, 3, 4)
341+
@test sols3 isa Array{<:TimeEvolutionSol}
342+
343+
# Verify that solutions make physical sense
344+
for (i, ωc) in enumerate(ωc_list)
345+
for (j, ωq) in enumerate(ωq_list)
346+
sol_0_e = sols2[1, i, j]
347+
sol_1_g = sols2[2, i, j]
348+
349+
# Check that expectation values are bounded and physical (take real part for physical observables)
350+
@test all(x -> real(x) >= -1e-4, sol_0_e.expect[1, :]) # a'a should be non-negative (with small tolerance)
351+
@test all(x -> real(x) >= -1e-4, sol_1_g.expect[1, :])
352+
end
353+
end
354+
355+
# Test with OperatorKet input
356+
ρ0 = operator_to_vector(ket2dm(ψ_0_e))
357+
ρ0_list = [operator_to_vector(ket2dm(ψ_0_e)), operator_to_vector(ket2dm(ψ_1_g))]
358+
sols4 = mesolve_map(H, ρ0_list, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
359+
360+
@test size(sols4) == (2, 3, 4)
361+
@test all(isoperket.(getfield.(sols4, :states) .|> first))
362+
363+
# Test with Operator input (density matrix)
364+
dm0_list = [ket2dm(ψ_0_e), ket2dm(ψ_1_g)]
365+
sols5 =
366+
mesolve_map(H, dm0_list, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
367+
368+
@test size(sols5) == (2, 3, 4)
369+
@test sols5 isa Array{<:TimeEvolutionSol}
370+
371+
@testset "Type Inference mesolve_map" begin
372+
@inferred mesolve_map(
373+
H,
374+
ψ0_list,
375+
tlist,
376+
c_ops;
377+
e_ops = e_ops,
378+
params = (ωc_list, ωq_list),
379+
progress_bar = Val(false),
380+
)
381+
end
382+
end
383+
299384
@testitem "mcsolve" setup=[TESetup] begin
300385
using SciMLOperators
301386
using Statistics

0 commit comments

Comments
 (0)