Skip to content

Commit c19e0d6

Browse files
Working mcsolve
1 parent 45495be commit c19e0d6

File tree

7 files changed

+72
-65
lines changed

7 files changed

+72
-65
lines changed

src/qobj/quantum_object_evo.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct QuantumObjectEvolution{DT<:AbstractSciMLOperator,ObjType<:QuantumObjectTy
9191
end
9292

9393
# Make the QuantumObjectEvolution, with the option to pre-multiply by a scalar
94-
function QuantumObjectEvolution(op_func_list::Tuple, α = true)
94+
function QuantumObjectEvolution(op_func_list::Tuple, α::Union{Nothing,Number} = nothing)
9595
op, data = _generate_data(op_func_list, α)
9696
dims = op.dims
9797
type = op.type
@@ -103,8 +103,15 @@ function QuantumObjectEvolution(op_func_list::Tuple, α = true)
103103
return QuantumObjectEvolution(data, type, dims)
104104
end
105105

106-
QuantumObjectEvolution(op::QuantumObject, α = true) =
107-
QuantumObjectEvolution(MatrixOperator* op.data), op.type, op.dims)
106+
QuantumObjectEvolution(op::QuantumObject, α::Union{Nothing,Number} = nothing) =
107+
QuantumObjectEvolution(_make_SciMLOperator(op, α), op.type, op.dims)
108+
109+
function QuantumObjectEvolution(op::QuantumObjectEvolution, α::Union{Nothing,Number} = nothing)
110+
if α isa Nothing
111+
return QuantumObjectEvolution(op.data, op.type, op.dims)
112+
end
113+
return QuantumObjectEvolution* op.data, op.type, op.dims)
114+
end
108115

109116
@generated function _generate_data(op_func_list::Tuple, α)
110117
op_func_list_types = op_func_list.parameters
@@ -158,10 +165,18 @@ end
158165
function _make_SciMLOperator(op_func::Tuple, α)
159166
T = eltype(op_func[1])
160167
update_func = (a, u, p, t) -> op_func[2](p, t)
168+
if α isa Nothing
169+
return ScalarOperator(zero(T), update_func) * MatrixOperator(op_func[1].data)
170+
end
161171
return ScalarOperator(zero(T), update_func) * MatrixOperator* op_func[1].data)
162172
end
163173

164-
_make_SciMLOperator(op::QuantumObject, α) = MatrixOperator* op.data)
174+
function _make_SciMLOperator(op::QuantumObject, α)
175+
if α isa Nothing
176+
return MatrixOperator(op.data)
177+
end
178+
return MatrixOperator* op.data)
179+
end
165180

166181
function (QO::QuantumObjectEvolution)(p, t)
167182
# We put 0 in the place of `u` because the time-dependence doesn't depend on the state

src/qobj/synonyms.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ Note that this functions is same as `QuantumObject(A; type=type, dims=dims)`.
1818
Qobj(A; kwargs...) = QuantumObject(A; kwargs...)
1919

2020
@doc raw"""
21-
QobjEvo(op_func_list::Union{Tuple,QuantumObject}, α::Real=true)
21+
QobjEvo(op_func_list::Union{Tuple,AbstractQuantumObject}, α::Union{Nothing,Number}=nothing)
2222
2323
Generate [`QuantumObjectEvolution`](@ref)
2424
2525
Note that this functions is same as `QuantumObjectEvolution(op_func_list)`. If `α` is provided, all the operators in `op_func_list` will be pre-multiplied by `α`.
2626
"""
27-
QobjEvo(op_func_list::Union{Tuple,QuantumObject}, α = true) = QuantumObjectEvolution(op_func_list, α)
27+
QobjEvo(op_func_list::Union{Tuple,AbstractQuantumObject}, α::Union{Nothing,Number} = nothing) =
28+
QuantumObjectEvolution(op_func_list, α)
2829

2930
@doc raw"""
3031
shape(A::QuantumObject)

src/time_evolution/mcsolve.jl

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -191,41 +191,38 @@ If the environmental measurements register a quantum jump, the wave function und
191191
- `prob::ODEProblem`: The ODEProblem for the Monte Carlo wave function time evolution.
192192
"""
193193
function mcsolveProblem(
194-
H::QuantumObject{MT1,OperatorQuantumObject},
195-
ψ0::QuantumObject{<:AbstractArray,KetQuantumObject},
194+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
195+
ψ0::QuantumObject{DT2,KetQuantumObject},
196196
tlist::AbstractVector,
197197
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
198198
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
199199
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
200-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
201200
params::NamedTuple = NamedTuple(),
202201
rng::AbstractRNG = default_rng(),
203202
jump_callback::TJC = ContinuousLindbladJumpCallback(),
204203
kwargs...,
205-
) where {MT1<:AbstractMatrix,TJC<:LindbladJumpCallbackType}
206-
check_dims(H, ψ0)
207-
204+
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
208205
haskey(kwargs, :save_idxs) &&
209206
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
210207

211208
c_ops isa Nothing &&
212209
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
213210

214-
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
211+
tlist = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
215212

216-
H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2
213+
H_eff_evo = QobjEvo(H) + QobjEvo(mapreduce(op -> op' * op, +, c_ops) / 2, -1im)
217214

218215
if e_ops isa Nothing
219-
expvals = Array{ComplexF64}(undef, 0, length(t_l))
216+
expvals = Array{ComplexF64}(undef, 0, length(tlist))
220217
is_empty_e_ops_mc = true
221-
e_ops2 = MT1[]
218+
e_ops_data = ()
222219
else
223-
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
224-
e_ops2 = get_data.(e_ops)
220+
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))
221+
e_ops_data = get_data.(e_ops)
225222
is_empty_e_ops_mc = isempty(e_ops)
226223
end
227224

228-
saveat = is_empty_e_ops_mc ? t_l : [t_l[end]]
225+
saveat = is_empty_e_ops_mc ? tlist : [tlist[end]]
229226
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
230227
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false))
231228
kwargs2 = merge(default_values, kwargs)
@@ -243,9 +240,9 @@ function mcsolveProblem(
243240

244241
params2 = (
245242
expvals = expvals,
246-
e_ops_mc = e_ops2,
243+
e_ops_mc = e_ops_data,
247244
is_empty_e_ops_mc = is_empty_e_ops_mc,
248-
progr_mc = ProgressBar(length(t_l), enable = false),
245+
progr_mc = ProgressBar(length(tlist), enable = false),
249246
traj_rng = rng,
250247
c_ops = c_ops_data,
251248
c_ops_herm = c_ops_herm_data,
@@ -259,53 +256,51 @@ function mcsolveProblem(
259256
params...,
260257
)
261258

262-
return mcsolveProblem(H_eff, ψ0, t_l, alg, H_t, params2, jump_callback; kwargs2...)
259+
return mcsolveProblem(H_eff_evo, ψ0, tlist, alg, params2, jump_callback; kwargs2...)
263260
end
264261

265262
function mcsolveProblem(
266-
H_eff::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
267-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
268-
t_l::AbstractVector,
263+
H_eff_evo::QuantumObjectEvolution{DT1,OperatorQuantumObject},
264+
ψ0::QuantumObject{DT2,KetQuantumObject},
265+
tlist::AbstractVector,
269266
alg::OrdinaryDiffEqAlgorithm,
270-
H_t::Union{Nothing,Function,TimeDependentOperatorSum},
271267
params::NamedTuple,
272268
jump_callback::DiscreteLindbladJumpCallback;
273269
kwargs...,
274-
) where {T1,T2}
270+
) where {DT1,DT2}
275271
cb1 = DiscreteCallback(LindbladJumpDiscreteCondition, LindbladJumpAffect!, save_positions = (false, false))
276-
cb2 = PresetTimeCallback(t_l, _save_func_mcsolve, save_positions = (false, false))
272+
cb2 = PresetTimeCallback(tlist, _save_func_mcsolve, save_positions = (false, false))
277273
kwargs2 = (; kwargs...)
278274
kwargs2 =
279275
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
280276
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
281277

282-
return sesolveProblem(H_eff, ψ0, t_l; alg = alg, H_t = H_t, params = params, kwargs2...)
278+
return sesolveProblem(H_eff_evo, ψ0, tlist; alg = alg, params = params, kwargs2...)
283279
end
284280

285281
function mcsolveProblem(
286-
H_eff::QuantumObject{<:AbstractArray,OperatorQuantumObject},
287-
ψ0::QuantumObject{<:AbstractArray,KetQuantumObject},
288-
t_l::AbstractVector,
282+
H_eff_evo::QuantumObjectEvolution{DT1,OperatorQuantumObject},
283+
ψ0::QuantumObject{DT2,KetQuantumObject},
284+
tlist::AbstractVector,
289285
alg::OrdinaryDiffEqAlgorithm,
290-
H_t::Union{Nothing,Function,TimeDependentOperatorSum},
291286
params::NamedTuple,
292287
jump_callback::ContinuousLindbladJumpCallback;
293288
kwargs...,
294-
)
289+
) where {DT1,DT2}
295290
cb1 = ContinuousCallback(
296291
LindbladJumpContinuousCondition,
297292
LindbladJumpAffect!,
298293
nothing,
299294
interp_points = jump_callback.interp_points,
300295
save_positions = (false, false),
301296
)
302-
cb2 = PresetTimeCallback(t_l, _save_func_mcsolve, save_positions = (false, false))
297+
cb2 = PresetTimeCallback(tlist, _save_func_mcsolve, save_positions = (false, false))
303298
kwargs2 = (; kwargs...)
304299
kwargs2 =
305300
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) :
306301
merge(kwargs2, (callback = CallbackSet(cb1, cb2),))
307302

308-
return sesolveProblem(H_eff, ψ0, t_l; alg = alg, H_t = H_t, params = params, kwargs2...)
303+
return sesolveProblem(H_eff_evo, ψ0, tlist; alg = alg, params = params, kwargs2...)
309304
end
310305

311306
@doc raw"""
@@ -391,13 +386,12 @@ If the environmental measurements register a quantum jump, the wave function und
391386
- `prob::EnsembleProblem with ODEProblem`: The Ensemble ODEProblem for the Monte Carlo wave function time evolution.
392387
"""
393388
function mcsolveEnsembleProblem(
394-
H::QuantumObject{MT1,OperatorQuantumObject},
395-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
389+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
390+
ψ0::QuantumObject{DT2,KetQuantumObject},
396391
tlist::AbstractVector,
397392
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
398393
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
399394
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
400-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
401395
params::NamedTuple = NamedTuple(),
402396
rng::AbstractRNG = default_rng(),
403397
ntraj::Int = 1,
@@ -407,7 +401,7 @@ function mcsolveEnsembleProblem(
407401
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
408402
progress_bar::Union{Val,Bool} = Val(true),
409403
kwargs...,
410-
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
404+
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
411405
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
412406
if ensemble_method isa EnsembleDistributed
413407
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@@ -429,7 +423,6 @@ function mcsolveEnsembleProblem(
429423
c_ops;
430424
alg = alg,
431425
e_ops = e_ops,
432-
H_t = H_t,
433426
params = merge(params, (global_rng = rng, seeds = seeds)),
434427
rng = rng,
435428
jump_callback = jump_callback,
@@ -533,13 +526,12 @@ If the environmental measurements register a quantum jump, the wave function und
533526
- `sol::TimeEvolutionMCSol`: The solution of the time evolution. See also [`TimeEvolutionMCSol`](@ref)
534527
"""
535528
function mcsolve(
536-
H::QuantumObject{MT1,OperatorQuantumObject},
537-
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
529+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
530+
ψ0::QuantumObject{DT2,KetQuantumObject},
538531
tlist::AbstractVector,
539532
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
540533
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
541534
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
542-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
543535
params::NamedTuple = NamedTuple(),
544536
rng::AbstractRNG = default_rng(),
545537
ntraj::Int = 1,
@@ -549,15 +541,14 @@ function mcsolve(
549541
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
550542
progress_bar::Union{Val,Bool} = Val(true),
551543
kwargs...,
552-
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
544+
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
553545
ens_prob_mc = mcsolveEnsembleProblem(
554546
H,
555547
ψ0,
556548
tlist,
557549
c_ops;
558550
alg = alg,
559551
e_ops = e_ops,
560-
H_t = H_t,
561552
params = params,
562553
rng = rng,
563554
ntraj = ntraj,
@@ -569,11 +560,12 @@ function mcsolve(
569560
kwargs...,
570561
)
571562

572-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
563+
return mcsolve(ens_prob_mc, tlist; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
573564
end
574565

575566
function mcsolve(
576-
ens_prob_mc::EnsembleProblem;
567+
ens_prob_mc::EnsembleProblem,
568+
tlist::AbstractVector;
577569
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
578570
ntraj::Int = 1,
579571
ensemble_method = EnsembleThreads(),
@@ -599,7 +591,7 @@ function mcsolve(
599591

600592
return TimeEvolutionMCSol(
601593
ntraj,
602-
_sol_1.prob.p.times,
594+
tlist,
603595
states,
604596
expvals,
605597
expvals_all,

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
7777
- `prob`: The `ODEProblem` for the Schrödinger time evolution of the system.
7878
"""
7979
function sesolveProblem(
80-
H::Union{QuantumObject{DT1,OperatorQuantumObject},Tuple},
80+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
8181
ψ0::QuantumObject{DT2,KetQuantumObject},
8282
tlist::AbstractVector;
8383
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
@@ -169,7 +169,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
169169
- `sol::TimeEvolutionSol`: The solution of the time evolution. See also [`TimeEvolutionSol`](@ref)
170170
"""
171171
function sesolve(
172-
H::Union{QuantumObject{DT1,OperatorQuantumObject},Tuple},
172+
H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple},
173173
ψ0::QuantumObject{DT2,KetQuantumObject},
174174
tlist::AbstractVector;
175175
alg::OrdinaryDiffEqAlgorithm = Tsit5(),

src/time_evolution/time_evolution.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,15 @@ A structure storing the results and some information from solving quantum trajec
7171
- `reltol::Real`: The relative tolerance which is used during the solving process.
7272
"""
7373
struct TimeEvolutionMCSol{
74-
TT<:Vector{<:Real},
74+
TT<:AbstractVector{<:Real},
7575
TS<:AbstractVector,
7676
TE<:Matrix{ComplexF64},
7777
TEA<:Array{ComplexF64,3},
7878
TJT<:Vector{<:Vector{<:Real}},
7979
TJW<:Vector{<:Vector{<:Integer}},
80+
AlgT<:OrdinaryDiffEqAlgorithm,
81+
AT<:Real,
82+
RT<:Real,
8083
}
8184
ntraj::Int
8285
times::TT
@@ -86,9 +89,9 @@ struct TimeEvolutionMCSol{
8689
jump_times::TJT
8790
jump_which::TJW
8891
converged::Bool
89-
alg::OrdinaryDiffEqAlgorithm
90-
abstol::Real
91-
reltol::Real
92+
alg::AlgT
93+
abstol::AT
94+
reltol::RT
9295
end
9396

9497
function Base.show(io::IO, sol::TimeEvolutionMCSol)

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,6 @@ function dsf_mcsolveEnsembleProblem(
618618
dsf_params::NamedTuple = NamedTuple();
619619
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
620620
e_ops::Function = (op_list, p) -> Vector{TOl}([]),
621-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
622621
params::NamedTuple = NamedTuple(),
623622
ntraj::Int = 1,
624623
ensemble_method = EnsembleThreads(),
@@ -670,7 +669,6 @@ function dsf_mcsolveEnsembleProblem(
670669
c_ops₀;
671670
e_ops = e_ops₀,
672671
alg = alg,
673-
H_t = H_t,
674672
params = params2,
675673
ntraj = ntraj,
676674
ensemble_method = ensemble_method,
@@ -690,7 +688,6 @@ end
690688
dsf_params::NamedTuple=NamedTuple();
691689
alg::OrdinaryDiffEqAlgorithm=Tsit5(),
692690
e_ops::Function=(op_list,p) -> Vector{TOl}([]),
693-
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
694691
params::NamedTuple=NamedTuple(),
695692
δα_list::Vector{<:Real}=fill(0.2, length(op_list)),
696693
ntraj::Int=1,
@@ -716,7 +713,6 @@ function dsf_mcsolve(
716713
dsf_params::NamedTuple = NamedTuple();
717714
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
718715
e_ops::Function = (op_list, p) -> Vector{TOl}([]),
719-
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
720716
params::NamedTuple = NamedTuple(),
721717
δα_list::Vector{<:Real} = fill(0.2, length(op_list)),
722718
ntraj::Int = 1,
@@ -736,7 +732,6 @@ function dsf_mcsolve(
736732
dsf_params;
737733
alg = alg,
738734
e_ops = e_ops,
739-
H_t = H_t,
740735
params = params,
741736
ntraj = ntraj,
742737
ensemble_method = ensemble_method,
@@ -747,5 +742,5 @@ function dsf_mcsolve(
747742
kwargs...,
748743
)
749744

750-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
745+
return mcsolve(ens_prob_mc, t_l; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
751746
end

0 commit comments

Comments
 (0)