Skip to content

Commit 5583257

Browse files
Working sesolve case
1 parent 0c85039 commit 5583257

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ Return the Callback that is responsible for saving the expectation values of the
139139
=#
140140
function _get_save_callback(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
141141
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
142-
if hasproperty(kwargs, :callback)
142+
if hasproperty(kwargs, :callback) && !isnothing(kwargs.callback)
143143
return _get_save_callback(kwargs.callback, method)
144144
else
145145
return nothing

src/time_evolution/sesolve.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,13 @@ end
183183
(TBA)
184184
"""
185185
function sesolve_map(
186-
H::Union{AbstractQuantumObject{Operator},Tuple},
187-
ψ0::Vector{<:QuantumObject{Ket}},
186+
H::Union{QuantumObjectEvolution{Operator},Tuple},
187+
ψ0::AbstractVector{<:QuantumObject{Ket}},
188188
tlist::AbstractVector;
189189
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
190190
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
191191
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
192-
params = [NullParameters()],
192+
params::Tuple = (NullParameters(),),
193193
progress_bar::Union{Val,Bool} = Val(true),
194194
kwargs...,
195195
)
@@ -214,7 +214,13 @@ function sesolve_map(
214214
ens_prob = TimeEvolutionProblem(
215215
EnsembleProblem(
216216
prob.prob,
217-
prob_func = (prob, i, repeat) -> remake(prob, u0 = iter[i][1], p = iter[i][2:end]),
217+
prob_func = (prob, i, repeat) -> remake(
218+
prob,
219+
f = deepcopy(prob.f.f),
220+
u0 = iter[i][1],
221+
p = iter[i][2:end],
222+
callback = haskey(prob.kwargs, :callback) ? deepcopy(prob.kwargs[:callback]) : nothing,
223+
),
218224
safetycopy = false,
219225
),
220226
prob.times,

test/core-test/time_evolution.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,35 +157,43 @@ end
157157
e_ops = TESetup.e_ops
158158

159159
g = 0.01
160+
161+
ψ_0_e = tensor(fock(N, 0), basis(2, 0))
162+
ψ_1_g = tensor(fock(N, 1), basis(2, 1))
163+
164+
ψ0_list = [ψ_0_e, ψ_1_g]
160165
ωc_list = [1, 1.01, 1.02]
161166
ωq_list = [0.96, 0.97, 0.98, 0.99]
167+
162168
tlist = range(0, 20 * 2π / g, 1000)
163169

164-
ωc(p, t) = p[1]
165-
ωq(p, t) = p[2]
166-
H = g * (a' * σm + a * σm') + QobjEvo(a' * a, ωc) + QobjEvo(σz / 2, ωq)
170+
ωc_fun(p, t) = p[1]
171+
ωq_fun(p, t) = p[2]
172+
H = QobjEvo(a' * a, ωc_fun) + QobjEvo(σz / 2, ωq_fun) + g * (a' * σm + a * σm')
167173

168-
sols1 = sesolve_map(H, ψ0, tlist; e_ops = e_ops, params = [ωc_list, ωq_list])
169-
sols2 = sesolve_map(H, [ψ0, ψ0], tlist; e_ops = e_ops, params = [ωc_list, ωq_list], progress_bar = Val(false))
174+
sols1 = sesolve_map(H, ψ_0_e, tlist; e_ops = e_ops, params = (ωc_list, ωq_list))
175+
sols2 = sesolve_map(H, ψ0_list, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
170176

171177
@test size(sols1) == (1, 3, 4)
172-
@test typeof(sols1) isa Array{<:TimeEvolutionSol}
178+
@test sols1 isa Array{<:TimeEvolutionSol}
173179
@test size(sols2) == (2, 3, 4)
174-
@test typeof(sols2) isa Array{<:TimeEvolutionSol}
180+
@test sols2 isa Array{<:TimeEvolutionSol}
175181
for (i, ωc) in enumerate(ωc_list)
176182
for (j, ωq) in enumerate(ωq_list)
177-
sol = sols1[1, i, j]
183+
sol_0_e = sols2[1, i, j]
184+
sol_1_g = sols2[2, i, j]
178185

179186
## Analytical solution for the expectation value of a' * a
180187
Ω_rabi = sqrt(g^2 + ((ωc - ωq) / 2)^2)
181188
amp_rabi = g^2 / Ω_rabi^2
182189

183-
@test sum(abs.(sol.expect[1, :] .- amp_rabi .* sin.(Ω_rabi * tlist) .^ 2)) / length(tlist) < 0.1
190+
@test sol_0_e.expect[1, :] amp_rabi .* sin.(Ω_rabi * tlist) .^ 2 atol = 1e-2
191+
@test sol_1_g.expect[1, :] 1 .- amp_rabi .* sin.(Ω_rabi * tlist) .^ 2 atol = 1e-2
184192
end
185193
end
186194

187195
@testset "Type Inference sesolve_map" begin
188-
@inferred sesolve_map(H, [ψ0, ψ0], tlist; e_ops = e_ops, params = [ωc_list, ωq_list], progress_bar = Val(false))
196+
@inferred sesolve_map(H, ψ0_list, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
189197
end
190198
end
191199

0 commit comments

Comments
 (0)