Skip to content

Commit 51c5d24

Browse files
ytdHuangalbertomercurio
authored andcommitted
add sesolve_map
1 parent 7a2a5fc commit 51c5d24

File tree

3 files changed

+105
-5
lines changed

3 files changed

+105
-5
lines changed

src/time_evolution/sesolve.jl

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export sesolveProblem, sesolve
1+
export sesolveProblem, sesolve, sesolve_map
22

33
_sesolve_make_U_QobjEvo(H) = -1im * QuantumObjectEvolution(H, type = Operator())
44

@@ -157,12 +157,16 @@ end
157157
function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(); kwargs...)
158158
sol = solve(prob.prob, alg; kwargs...)
159159

160-
ψt = map-> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u)
160+
return _gen_sesolve_solution(sol, prob.times, prob.dimensions)
161+
end
162+
163+
function _gen_sesolve_solution(sol, times, dimensions)
164+
ψt = map-> QuantumObject(ϕ, type = Ket(), dims = dimensions), sol.u)
161165

162166
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
163167

164168
return TimeEvolutionSol(
165-
prob.times,
169+
times,
166170
sol.t,
167171
ψt,
168172
_get_expvals(sol, SaveFuncSESolve),
@@ -172,3 +176,55 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit
172176
kwargs.reltol,
173177
)
174178
end
179+
180+
@doc raw"""
181+
sesolve_map
182+
183+
(TBA)
184+
"""
185+
function sesolve_map(
186+
H::Union{AbstractQuantumObject{Operator},Tuple},
187+
ψ0::Vector{<:QuantumObject{Ket}},
188+
tlist::AbstractVector;
189+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
190+
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
191+
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
192+
params = [NullParameters()],
193+
progress_bar::Union{Val,Bool} = Val(true),
194+
kwargs...,
195+
)
196+
# mapping initial states and parameters
197+
ψ0_iter = map(get_data, ψ0)
198+
iter = collect(Iterators.product(ψ0_iter, params...))
199+
ntraj = length(iter)
200+
201+
# we disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
202+
prob = sesolveProblem(
203+
H,
204+
first(ψ0),
205+
tlist;
206+
e_ops = e_ops,
207+
params = first(iter)[2:end],
208+
progress_bar = Val(false),
209+
kwargs...,
210+
)
211+
212+
# generate and solve ensemble problem
213+
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj) # setup global progress bar
214+
ens_prob = TimeEvolutionProblem(
215+
EnsembleProblem(
216+
prob.prob,
217+
prob_func = (prob, i, repeat) -> remake(prob, u0 = iter[i][1], p = iter[i][2:end]),
218+
safetycopy = false,
219+
),
220+
prob.times,
221+
prob.dimensions,
222+
(progr = _output_func[2], channel = _output_func[3]),
223+
)
224+
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
225+
226+
# handle solution and make it become an Array of TimeEvolutionSol
227+
return reshape(map(i -> _gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions), eachindex(sol)), size(iter))
228+
end
229+
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
230+
sesolve_map(H, [ψ0], tlist; kwargs...)

src/time_evolution/time_evolution.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,14 @@ This is very useful especially for dispatching which method to use to update the
358358
=#
359359

360360
# Output function with progress bar update
361+
_ensemble_output_func_progress(sol, i, progr, ::Nothing) = next!(progr)
361362
function _ensemble_output_func_progress(sol, i, progr, output_func)
362363
next!(progr)
363364
return output_func(sol, i)
364365
end
365366

366367
# Output function with distributed channel update for progress bar
368+
_ensemble_output_func_distributed(sol, i, channel, ::Nothing) = put!(channel, true)
367369
function _ensemble_output_func_distributed(sol, i, channel, output_func)
368370
put!(channel, true)
369371
return output_func(sol, i)
@@ -373,7 +375,7 @@ function _ensemble_dispatch_output_func(
373375
::ET,
374376
progress_bar,
375377
ntraj,
376-
output_func,
378+
output_func = nothing,
377379
) where {ET<:Union{EnsembleSerial,EnsembleThreads}}
378380
if getVal(progress_bar)
379381
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
@@ -387,7 +389,7 @@ function _ensemble_dispatch_output_func(
387389
::ET,
388390
progress_bar,
389391
ntraj,
390-
output_func,
392+
output_func = nothing,
391393
) where {ET<:Union{EnsembleSplitThreads,EnsembleDistributed}}
392394
if getVal(progress_bar)
393395
progr = ProgressBar(ntraj, enable = getVal(progress_bar))

test/core-test/time_evolution.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,48 @@ end
147147
end
148148
end
149149

150+
@testitem "sesolve_map" setup=[TESetup] begin
151+
152+
# Get parameters from TESetup to simplify the code
153+
a = TESetup.a
154+
σz = TESetup.σz
155+
σm = TESetup.σm
156+
ψ0 = TESetup.ψ0
157+
e_ops = TESetup.e_ops
158+
159+
g = 0.01
160+
ωc_list = [1, 1.01, 1.02]
161+
ωq_list = [0.96, 0.97, 0.98, 0.99]
162+
tlist = range(0, 20 * 2π / g, 1000)
163+
164+
ωc(p, t) = p[1]
165+
ωq(p, t) = p[2]
166+
H = g * (a' * σm + a * σm') + QobjEvo(a' * a, ωc) + QobjEvo(σz / 2, ωq)
167+
168+
sols1 = sesolve_map(H, ψ0, tlist; e_ops = e_ops, params = [ωc_list, ωq_list])
169+
sols2 = sesolve_map(H, [ψ0, ψ0], tlist; e_ops = e_ops, params = [ωc_list, ωq_list], progress_bar = Val(false))
170+
171+
@test size(sols1) == (1, 3, 4)
172+
@test typeof(sols1) isa Array{<:TimeEvolutionSol}
173+
@test size(sols2) == (2, 3, 4)
174+
@test typeof(sols2) isa Array{<:TimeEvolutionSol}
175+
for (i, ωc) in enumerate(ωc_list)
176+
for (j, ωq) in enumerate(ωq_list)
177+
sol = sols1[1, i, j]
178+
179+
## Analytical solution for the expectation value of a' * a
180+
Ω_rabi = sqrt(g^2 + ((ωc - ωq) / 2)^2)
181+
amp_rabi = g^2 / Ω_rabi^2
182+
183+
@test sum(abs.(sol.expect[1, :] .- amp_rabi .* sin.(Ω_rabi * tlist) .^ 2)) / length(tlist) < 0.1
184+
end
185+
end
186+
187+
@testset "Type Inference sesolve_map" begin
188+
@inferred sesolve_map(H, [ψ0, ψ0], tlist; e_ops = e_ops, params = [ωc_list, ωq_list], progress_bar = Val(false))
189+
end
190+
end
191+
150192
@testitem "mesolve" setup=[TESetup] begin
151193
using SciMLOperators
152194

0 commit comments

Comments
 (0)