Skip to content

Commit 3dcec43

Browse files
Rebase commits
1 parent db80e1b commit 3dcec43

File tree

8 files changed

+117
-94
lines changed

8 files changed

+117
-94
lines changed

src/QuantumToolbox.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import SciMLBase:
2121
reinit!,
2222
remake,
2323
u_modified!,
24+
ODEFunction,
2425
ODEProblem,
2526
SDEProblem,
2627
EnsembleProblem,

src/qobj/quantum_object_evo.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct QuantumObjectEvolution{DT<:AbstractSciMLOperator,ObjType<:QuantumObjectTy
9191
end
9292

9393
# Make the QuantumObjectEvolution, with the option to pre-multiply by a scalar
94-
function QuantumObjectEvolution(op_func_list::Tuple, α = true)
94+
function QuantumObjectEvolution(op_func_list::Tuple, α::Union{Nothing,Number} = nothing)
9595
op, data = _generate_data(op_func_list, α)
9696
dims = op.dims
9797
type = op.type
@@ -103,8 +103,15 @@ function QuantumObjectEvolution(op_func_list::Tuple, α = true)
103103
return QuantumObjectEvolution(data, type, dims)
104104
end
105105

106-
QuantumObjectEvolution(op::QuantumObject, α = true) =
107-
QuantumObjectEvolution(MatrixOperator* op.data), op.type, op.dims)
106+
QuantumObjectEvolution(op::QuantumObject, α::Union{Nothing,Number} = nothing) =
107+
QuantumObjectEvolution(_make_SciMLOperator(op, α), op.type, op.dims)
108+
109+
function QuantumObjectEvolution(op::QuantumObjectEvolution, α::Union{Nothing,Number} = nothing)
110+
if α isa Nothing
111+
return QuantumObjectEvolution(op.data, op.type, op.dims)
112+
end
113+
return QuantumObjectEvolution* op.data, op.type, op.dims)
114+
end
108115

109116
@generated function _generate_data(op_func_list::Tuple, α)
110117
op_func_list_types = op_func_list.parameters
@@ -158,10 +165,18 @@ end
158165
function _make_SciMLOperator(op_func::Tuple, α)
159166
T = eltype(op_func[1])
160167
update_func = (a, u, p, t) -> op_func[2](p, t)
168+
if α isa Nothing
169+
return ScalarOperator(zero(T), update_func) * MatrixOperator(op_func[1].data)
170+
end
161171
return ScalarOperator(zero(T), update_func) * MatrixOperator* op_func[1].data)
162172
end
163173

164-
_make_SciMLOperator(op::QuantumObject, α) = MatrixOperator* op.data)
174+
function _make_SciMLOperator(op::QuantumObject, α)
175+
if α isa Nothing
176+
return MatrixOperator(op.data)
177+
end
178+
return MatrixOperator* op.data)
179+
end
165180

166181
function (QO::QuantumObjectEvolution)(p, t)
167182
# We put 0 in the place of `u` because the time-dependence doesn't depend on the state

src/qobj/synonyms.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ Note that this functions is same as `QuantumObject(A; type=type, dims=dims)`.
1818
Qobj(A; kwargs...) = QuantumObject(A; kwargs...)
1919

2020
@doc raw"""
21-
QobjEvo(op_func_list::Union{Tuple,QuantumObject}, α::Real=true)
21+
QobjEvo(op_func_list::Union{Tuple,AbstractQuantumObject}, α::Union{Nothing,Number}=nothing)
2222
2323
Generate [`QuantumObjectEvolution`](@ref)
2424
2525
Note that this functions is same as `QuantumObjectEvolution(op_func_list)`. If `α` is provided, all the operators in `op_func_list` will be pre-multiplied by `α`.
2626
"""
27-
QobjEvo(op_func_list::Union{Tuple,QuantumObject}, α = true) = QuantumObjectEvolution(op_func_list, α)
27+
QobjEvo(op_func_list::Union{Tuple,AbstractQuantumObject}, α::Union{Nothing,Number} = nothing) =
28+
QuantumObjectEvolution(op_func_list, α)
2829

2930
@doc raw"""
3031
shape(A::QuantumObject)

src/time_evolution/mcsolve.jl

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, j
117117
return jump_which[i] = sol_i.prob.p.jump_which
118118
end
119119

120+
_mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops) = QobjEvo(H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2)
121+
_mcsolve_make_Heff_QobjEvo(H::Tuple, c_ops) = QobjEvo((H..., -1im * mapreduce(op -> op' * op, +, c_ops) / 2))
122+
_mcsolve_make_Heff_QobjEvo(H::QuantumObjectEvolution, c_ops) =
123+
H + QobjEvo(mapreduce(op -> op' * op, +, c_ops), -1im / 2)
124+
120125
@doc raw"""
121126
mcsolveProblem(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
122127
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
@@ -191,41 +196,38 @@ If the environmental measurements register a quantum jump, the wave function und
191196
- `prob::ODEProblem`: The ODEProblem for the Monte Carlo wave function time evolution.
192197
"""
193198
function mcsolveProblem(
194-
H::QuantumObject{MT1,OperatorQuantumObject},
195-
ψ0::QuantumObject{<:AbstractArray,KetQuantumObject},
199+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
200+
ψ0::QuantumObject{DT2,KetQuantumObject},
196201
tlist::AbstractVector,
197202
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
198203
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
199204
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
200-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
201205
params::NamedTuple = NamedTuple(),
202206
rng::AbstractRNG = default_rng(),
203207
jump_callback::TJC = ContinuousLindbladJumpCallback(),
204208
kwargs...,
205-
) where {MT1<:AbstractMatrix,TJC<:LindbladJumpCallbackType}
206-
check_dims(H, ψ0)
207-
209+
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
208210
haskey(kwargs, :save_idxs) &&
209211
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
210212

211213
c_ops isa Nothing &&
212214
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
213215

214-
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
216+
tlist = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
215217

216-
H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2
218+
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, c_ops)
217219

218220
if e_ops isa Nothing
219-
expvals = Array{ComplexF64}(undef, 0, length(t_l))
221+
expvals = Array{ComplexF64}(undef, 0, length(tlist))
220222
is_empty_e_ops_mc = true
221-
e_ops2 = MT1[]
223+
e_ops_data = ()
222224
else
223-
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
224-
e_ops2 = get_data.(e_ops)
225+
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
226+
e_ops_data = get_data.(e_ops)
225227
is_empty_e_ops_mc = isempty(e_ops)
226228
end
227229

228-
saveat = is_empty_e_ops_mc ? t_l : [t_l[end]]
230+
saveat = is_empty_e_ops_mc ? tlist : [tlist[end]]
229231
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
230232
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false))
231233
kwargs2 = merge(default_values, kwargs)
@@ -243,9 +245,9 @@ function mcsolveProblem(
243245

244246
params2 = (
245247
expvals = expvals,
246-
e_ops_mc = e_ops2,
248+
e_ops_mc = e_ops_data,
247249
is_empty_e_ops_mc = is_empty_e_ops_mc,
248-
progr_mc = ProgressBar(length(t_l), enable = false),
250+
progr_mc = ProgressBar(length(tlist), enable = false),
249251
traj_rng = rng,
250252
c_ops = c_ops_data,
251253
c_ops_herm = c_ops_herm_data,
@@ -259,53 +261,51 @@ function mcsolveProblem(
259261
params...,
260262
)
261263

262-
return mcsolveProblem(H_eff, ψ0, t_l, alg, H_t, params2, jump_callback; kwargs2...)
264+
return mcsolveProblem(H_eff_evo, ψ0, tlist, alg, params2, jump_callback; kwargs2...)
263265
end
264266

265267
function mcsolveProblem(
266-
H_eff::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
267-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
268-
t_l::AbstractVector,
268+
H_eff_evo::QuantumObjectEvolution{DT1,OperatorQuantumObject},
269+
ψ0::QuantumObject{DT2,KetQuantumObject},
270+
tlist::AbstractVector,
269271
alg::OrdinaryDiffEqAlgorithm,
270-
H_t::Union{Nothing,Function,TimeDependentOperatorSum},
271272
params::NamedTuple,
272273
jump_callback::DiscreteLindbladJumpCallback;
273274
kwargs...,
274-
) where {T1,T2}
275+
) where {DT1,DT2}
275276
cb1 = DiscreteCallback(LindbladJumpDiscreteCondition, LindbladJumpAffect!, save_positions = (false, false))
276-
cb2 = PresetTimeCallback(t_l, _save_func_mcsolve, save_positions = (false, false))
277+
cb2 = PresetTimeCallback(tlist, _save_func_mcsolve, save_positions = (false, false))
277278
kwargs2 = (; kwargs...)
278279
kwargs2 =
279280
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
280281
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
281282

282-
return sesolveProblem(H_eff, ψ0, t_l; alg = alg, H_t = H_t, params = params, kwargs2...)
283+
return sesolveProblem(H_eff_evo, ψ0, tlist; alg = alg, params = params, kwargs2...)
283284
end
284285

285286
function mcsolveProblem(
286-
H_eff::QuantumObject{<:AbstractArray,OperatorQuantumObject},
287-
ψ0::QuantumObject{<:AbstractArray,KetQuantumObject},
288-
t_l::AbstractVector,
287+
H_eff_evo::QuantumObjectEvolution{DT1,OperatorQuantumObject},
288+
ψ0::QuantumObject{DT2,KetQuantumObject},
289+
tlist::AbstractVector,
289290
alg::OrdinaryDiffEqAlgorithm,
290-
H_t::Union{Nothing,Function,TimeDependentOperatorSum},
291291
params::NamedTuple,
292292
jump_callback::ContinuousLindbladJumpCallback;
293293
kwargs...,
294-
)
294+
) where {DT1,DT2}
295295
cb1 = ContinuousCallback(
296296
LindbladJumpContinuousCondition,
297297
LindbladJumpAffect!,
298298
nothing,
299299
interp_points = jump_callback.interp_points,
300300
save_positions = (false, false),
301301
)
302-
cb2 = PresetTimeCallback(t_l, _save_func_mcsolve, save_positions = (false, false))
302+
cb2 = PresetTimeCallback(tlist, _save_func_mcsolve, save_positions = (false, false))
303303
kwargs2 = (; kwargs...)
304304
kwargs2 =
305305
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
306306
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
307307

308-
return sesolveProblem(H_eff, ψ0, t_l; alg = alg, H_t = H_t, params = params, kwargs2...)
308+
return sesolveProblem(H_eff_evo, ψ0, tlist; alg = alg, params = params, kwargs2...)
309309
end
310310

311311
@doc raw"""
@@ -391,13 +391,12 @@ If the environmental measurements register a quantum jump, the wave function und
391391
- `prob::EnsembleProblem with ODEProblem`: The Ensemble ODEProblem for the Monte Carlo wave function time evolution.
392392
"""
393393
function mcsolveEnsembleProblem(
394-
H::QuantumObject{MT1,OperatorQuantumObject},
395-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
394+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
395+
ψ0::QuantumObject{DT2,KetQuantumObject},
396396
tlist::AbstractVector,
397397
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
398398
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
399399
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
400-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
401400
params::NamedTuple = NamedTuple(),
402401
rng::AbstractRNG = default_rng(),
403402
ntraj::Int = 1,
@@ -407,7 +406,7 @@ function mcsolveEnsembleProblem(
407406
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
408407
progress_bar::Union{Val,Bool} = Val(true),
409408
kwargs...,
410-
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
409+
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
411410
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
412411
if ensemble_method isa EnsembleDistributed
413412
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@@ -429,7 +428,6 @@ function mcsolveEnsembleProblem(
429428
c_ops;
430429
alg = alg,
431430
e_ops = e_ops,
432-
H_t = H_t,
433431
params = merge(params, (global_rng = rng, seeds = seeds)),
434432
rng = rng,
435433
jump_callback = jump_callback,
@@ -533,13 +531,12 @@ If the environmental measurements register a quantum jump, the wave function und
533531
- `sol::TimeEvolutionMCSol`: The solution of the time evolution. See also [`TimeEvolutionMCSol`](@ref)
534532
"""
535533
function mcsolve(
536-
H::QuantumObject{MT1,OperatorQuantumObject},
537-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
534+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
535+
ψ0::QuantumObject{DT2,KetQuantumObject},
538536
tlist::AbstractVector,
539537
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
540538
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
541539
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
542-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
543540
params::NamedTuple = NamedTuple(),
544541
rng::AbstractRNG = default_rng(),
545542
ntraj::Int = 1,
@@ -549,15 +546,14 @@ function mcsolve(
549546
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
550547
progress_bar::Union{Val,Bool} = Val(true),
551548
kwargs...,
552-
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
549+
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
553550
ens_prob_mc = mcsolveEnsembleProblem(
554551
H,
555552
ψ0,
556553
tlist,
557554
c_ops;
558555
alg = alg,
559556
e_ops = e_ops,
560-
H_t = H_t,
561557
params = params,
562558
rng = rng,
563559
ntraj = ntraj,
@@ -569,11 +565,12 @@ function mcsolve(
569565
kwargs...,
570566
)
571567

572-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
568+
return mcsolve(ens_prob_mc, tlist; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
573569
end
574570

575571
function mcsolve(
576-
ens_prob_mc::EnsembleProblem;
572+
ens_prob_mc::EnsembleProblem,
573+
tlist::AbstractVector;
577574
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
578575
ntraj::Int = 1,
579576
ensemble_method = EnsembleThreads(),
@@ -599,7 +596,7 @@ function mcsolve(
599596

600597
return TimeEvolutionMCSol(
601598
ntraj,
602-
_sol_1.prob.p.times,
599+
tlist,
603600
states,
604601
expvals,
605602
expvals_all,

0 commit comments

Comments
 (0)