Skip to content

Commit 48f8c41

Browse files
Add progress_bar in mcsolve, ssesolve and dsf_mcsolve (#254)
1 parent bde9142 commit 48f8c41

File tree

10 files changed

+136
-84
lines changed

10 files changed

+136
-84
lines changed

Project.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1010
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
11+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1112
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1213
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1314
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
@@ -37,23 +38,24 @@ CUDA = "5"
3738
DiffEqBase = "6"
3839
DiffEqCallbacks = "2 - 3.1, 3.8, 4"
3940
DiffEqNoiseProcess = "5"
41+
Distributed = "1"
4042
FFTW = "1.5"
4143
Graphs = "1.7"
4244
IncompleteLU = "0.2"
43-
LinearAlgebra = "<0.0.1, 1"
45+
LinearAlgebra = "1"
4446
LinearSolve = "2"
4547
OrdinaryDiffEqCore = "1"
4648
OrdinaryDiffEqTsit5 = "1"
47-
Pkg = "<0.0.1, 1"
48-
Random = "<0.0.1, 1"
49+
Pkg = "1"
50+
Random = "1"
4951
Reexport = "1"
5052
SciMLBase = "2"
5153
SciMLOperators = "0.3"
52-
SparseArrays = "<0.0.1, 1"
54+
SparseArrays = "1"
5355
SpecialFunctions = "2"
5456
StaticArraysCore = "1"
5557
StochasticDiffEq = "6"
56-
Test = "<0.0.1, 1"
58+
Test = "1"
5759
julia = "1.10"
5860

5961
[extras]

src/QuantumToolbox.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import DiffEqNoiseProcess: RealWienerProcess
4040

4141
# other dependencies (in alphabetical order)
4242
import ArrayInterface: allowed_getindex, allowed_setindex!
43+
import Distributed: RemoteChannel
4344
import FFTW: fft, fftshift
4445
import Graphs: connected_components, DiGraph
4546
import IncompleteLU: ilu

src/qobj/operators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ function tunneling(N::Int, m::Int = 1; sparse::Union{Bool,Val} = Val(false))
511511
(m < 1) && throw(ArgumentError("The number of excitations (m) cannot be less than 1"))
512512

513513
data = ones(ComplexF64, N - m)
514-
if getVal(makeVal(sparse))
514+
if getVal(sparse)
515515
return QuantumObject(spdiagm(m => data, -m => data); type = Operator, dims = N)
516516
else
517517
return QuantumObject(diagm(m => data, -m => data); type = Operator, dims = N)

src/qobj/states.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ It is also possible to specify the list of dimensions `dims` if different subsys
3434
If you want to keep type stability, it is recommended to use `fock(N, j, dims=dims, sparse=Val(sparse))` instead of `fock(N, j, dims=dims, sparse=sparse)`. Consider also to use `dims` as a `Tuple` or `SVector` instead of `Vector`. See [this link](https://docs.julialang.org/en/v1/manual/performance-tips/#man-performance-value-type) and the [related Section](@ref doc:Type-Stability) about type stability for more details.
3535
"""
3636
function fock(N::Int, j::Int = 0; dims::Union{Int,AbstractVector{Int},Tuple} = N, sparse::Union{Bool,Val} = Val(false))
37-
if getVal(makeVal(sparse))
37+
if getVal(sparse)
3838
array = sparsevec([j + 1], [1.0 + 0im], N)
3939
else
4040
array = zeros(ComplexF64, N)
@@ -130,7 +130,7 @@ function thermal_dm(N::Int, n::Real; sparse::Union{Bool,Val} = Val(false))
130130
β = log(1.0 / n + 1.0)
131131
N_list = Array{Float64}(0:N-1)
132132
data = exp.(-β .* N_list)
133-
if getVal(makeVal(sparse))
133+
if getVal(sparse)
134134
return QuantumObject(spdiagm(0 => data ./ sum(data)), Operator, N)
135135
else
136136
return QuantumObject(diagm(0 => data ./ sum(data)), Operator, N)

src/time_evolution/mcsolve.jl

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ end
8383
function _mcsolve_output_func(sol, i)
8484
resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1)
8585
resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1)
86+
put!(sol.prob.p.progr_channel, true)
8687
return (sol, false)
8788
end
8889

@@ -204,7 +205,8 @@ function mcsolveProblem(
204205
end
205206

206207
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
207-
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
208+
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
209+
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false))
208210
kwargs2 = merge(default_values, kwargs)
209211

210212
cache_mc = similar(ψ0.data)
@@ -396,15 +398,20 @@ end
396398
mcsolve(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
397399
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
398400
tlist::AbstractVector,
399-
c_ops::Union{Nothing,AbstractVector,Tuple}=nothing;
400-
alg::OrdinaryDiffEqAlgorithm=Tsit5(),
401-
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
402-
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
403-
params::NamedTuple=NamedTuple(),
404-
ntraj::Int=1,
405-
ensemble_method=EnsembleThreads(),
406-
jump_callback::TJC=ContinuousLindbladJumpCallback(),
407-
kwargs...)
401+
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
402+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
403+
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
404+
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
405+
params::NamedTuple = NamedTuple(),
406+
seeds::Union{Nothing,Vector{Int}} = nothing,
407+
ntraj::Int = 1,
408+
ensemble_method = EnsembleThreads(),
409+
jump_callback::TJC = ContinuousLindbladJumpCallback(),
410+
prob_func::Function = _mcsolve_prob_func,
411+
output_func::Function = _mcsolve_output_func,
412+
progress_bar::Union{Val,Bool} = Val(true),
413+
kwargs...,
414+
)
408415
409416
Time evolution of an open quantum system using quantum trajectories.
410417
@@ -457,6 +464,7 @@ If the environmental measurements register a quantum jump, the wave function und
457464
- `prob_func::Function`: Function to use for generating the ODEProblem.
458465
- `output_func::Function`: Function to use for generating the output of a single trajectory.
459466
- `kwargs...`: Additional keyword arguments to pass to the solver.
467+
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
460468
461469
# Notes
462470
@@ -486,29 +494,42 @@ function mcsolve(
486494
jump_callback::TJC = ContinuousLindbladJumpCallback(),
487495
prob_func::Function = _mcsolve_prob_func,
488496
output_func::Function = _mcsolve_output_func,
497+
progress_bar::Union{Val,Bool} = Val(true),
489498
kwargs...,
490499
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
491500
if !isnothing(seeds) && length(seeds) != ntraj
492501
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
493502
end
494503

495-
ens_prob_mc = mcsolveEnsembleProblem(
496-
H,
497-
ψ0,
498-
tlist,
499-
c_ops;
500-
alg = alg,
501-
e_ops = e_ops,
502-
H_t = H_t,
503-
params = params,
504-
seeds = seeds,
505-
jump_callback = jump_callback,
506-
prob_func = prob_func,
507-
output_func = output_func,
508-
kwargs...,
509-
)
504+
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
505+
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
506+
@async while take!(progr_channel)
507+
next!(progr)
508+
end
510509

511-
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
510+
# Stop the async task if an error occurs
511+
try
512+
ens_prob_mc = mcsolveEnsembleProblem(
513+
H,
514+
ψ0,
515+
tlist,
516+
c_ops;
517+
alg = alg,
518+
e_ops = e_ops,
519+
H_t = H_t,
520+
params = merge(params, (progr_channel = progr_channel,)),
521+
seeds = seeds,
522+
jump_callback = jump_callback,
523+
prob_func = prob_func,
524+
output_func = output_func,
525+
kwargs...,
526+
)
527+
528+
return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
529+
catch e
530+
put!(progr_channel, false)
531+
rethrow()
532+
end
512533
end
513534

514535
function mcsolve(
@@ -518,6 +539,9 @@ function mcsolve(
518539
ensemble_method = EnsembleThreads(),
519540
)
520541
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)
542+
543+
put!(sol[:, 1].prob.p.progr_channel, false)
544+
521545
_sol_1 = sol[:, 1]
522546

523547
expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)

src/time_evolution/mesolve.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,13 @@ function mesolveProblem(
120120
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
121121

122122
is_time_dependent = !(H_t isa Nothing)
123-
progress_bar_val = makeVal(progress_bar)
124123

125124
ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type
126125

127126
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
128127

129128
L = liouvillian(H, c_ops).data
130-
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
129+
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))
131130

132131
if e_ops isa Nothing
133132
expvals = Array{ComplexF64}(undef, 0, length(t_l))
@@ -158,7 +157,7 @@ function mesolveProblem(
158157
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
159158
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
160159
kwargs2 = merge(default_values, kwargs)
161-
kwargs3 = _generate_mesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)
160+
kwargs3 = _generate_mesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)
162161

163162
dudt! = is_time_dependent ? mesolve_td_dudt! : mesolve_ti_dudt!
164163

@@ -241,7 +240,7 @@ function mesolve(
241240
e_ops = e_ops,
242241
H_t = H_t,
243242
params = params,
244-
progress_bar = makeVal(progress_bar),
243+
progress_bar = progress_bar,
245244
kwargs...,
246245
)
247246

src/time_evolution/sesolve.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,13 @@ function sesolveProblem(
101101
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
102102

103103
is_time_dependent = !(H_t isa Nothing)
104-
progress_bar_val = makeVal(progress_bar)
105104

106105
ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
107106

108107
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
109108

110109
U = -1im * get_data(H)
111-
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
110+
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))
112111

113112
if e_ops isa Nothing
114113
expvals = Array{ComplexF64}(undef, 0, length(t_l))
@@ -135,7 +134,7 @@ function sesolveProblem(
135134
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
136135
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
137136
kwargs2 = merge(default_values, kwargs)
138-
kwargs3 = _generate_sesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)
137+
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)
139138

140139
dudt! = is_time_dependent ? sesolve_td_dudt! : sesolve_ti_dudt!
141140

@@ -203,7 +202,7 @@ function sesolve(
203202
e_ops = e_ops,
204203
H_t = H_t,
205204
params = params,
206-
progress_bar = makeVal(progress_bar),
205+
progress_bar = progress_bar,
207206
kwargs...,
208207
)
209208

0 commit comments

Comments
 (0)