Skip to content

Commit 8e72246

Browse files
Remove ProgressBar from ODE parameters
1 parent eb05e53 commit 8e72246

File tree

6 files changed

+94
-70
lines changed

6 files changed

+94
-70
lines changed

src/qobj/quantum_object_evo.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ function QuantumObjectEvolution(
214214
return QuantumObjectEvolution* op.data, type, op.dims)
215215
end
216216

217+
_all_equal(dims) = all(x -> x == first(dims), dims)
218+
217219
#=
218220
_QobjEvo_generate_data(op_func_list::Tuple, α; f::Function=identity)
219221
@@ -269,7 +271,7 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
269271
quote
270272
dims = tuple($(dims_expr...))
271273

272-
length(unique(dims)) == 1 || throw(ArgumentError("The dimensions of the operators must be the same."))
274+
_all_equal(dims) || throw(ArgumentError("The dimensions of the operators must be the same."))
273275

274276
data_expr_const = $qobj_expr_const isa Integer ? $qobj_expr_const : _make_SciMLOperator($qobj_expr_const, α)
275277

src/time_evolution/callback_helpers.jl

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,58 @@ This file contains helper functions for callbacks. The affect! function are defi
44

55
########## SESOLVE ##########
66

7-
struct SaveFuncSESolve{T1,T2}
8-
e_ops::T1
9-
is_empty_e_ops::T2
7+
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT}
8+
e_ops::TE
9+
progr::PT
10+
iter::IT
1011
end
1112

12-
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.is_empty_e_ops)
13-
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func_sesolve(integrator)
13+
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter)
14+
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func_sesolve(integrator, f.progr)
1415

1516
##
1617

1718
# When e_ops is Nothing
18-
function _save_func_sesolve(integrator)
19-
next!(integrator.p.progr)
19+
function _save_func_sesolve(integrator, progr)
20+
next!(progr)
21+
u_modified!(integrator, false)
22+
return nothing
23+
end
24+
25+
# When progr is Nothing
26+
function _save_func_sesolve(integrator, progr::Nothing)
2027
u_modified!(integrator, false)
2128
return nothing
2229
end
2330

2431
# When e_ops is a list of operators
25-
function _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
32+
function _save_func_sesolve(integrator, e_ops, progr, iter)
2633
expvals = integrator.p.expvals
27-
progr = integrator.p.progr
28-
if !is_empty_e_ops
29-
ψ = integrator.u
30-
_expect = op -> dot(ψ, op, ψ)
31-
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
32-
end
33-
return _save_func_sesolve(integrator)
34+
ψ = integrator.u
35+
_expect = op -> dot(ψ, op, ψ)
36+
@. expvals[:, iter[]] = _expect(e_ops)
37+
iter[] += 1
38+
39+
return _save_func_sesolve(integrator, progr)
3440
end
3541

36-
function _generate_sesolve_callback(e_ops, tlist)
37-
is_empty_e_ops = e_ops isa Nothing ? true : isempty(e_ops)
38-
_save_affect! = SaveFuncSESolve(get_data.(e_ops), is_empty_e_ops)
42+
function _generate_sesolve_callback(e_ops, tlist, progress_bar)
43+
e_ops_data = e_ops isa Nothing ? nothing : get_data.(e_ops)
44+
45+
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing
46+
47+
_save_affect! = SaveFuncSESolve(e_ops_data, progr, Ref(1))
3948
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
4049
end
4150

4251
########## MCSOLVE ##########
4352

44-
struct SaveFuncMCSolve{T1,T2}
45-
e_ops::T1
46-
is_empty_e_ops::T2
53+
struct SaveFuncMCSolve{TE,IT}
54+
e_ops::TE
55+
iter::IT
4756
end
4857

49-
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.is_empty_e_ops)
58+
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter)
5059

5160
struct LindbladJump{T1,T2}
5261
c_ops::T1
@@ -57,18 +66,17 @@ end
5766

5867
##
5968

60-
function _save_func_mcsolve(integrator, e_ops, is_empty_e_ops)
69+
function _save_func_mcsolve(integrator, e_ops, iter)
6170
expvals = integrator.p.expvals
62-
progr = integrator.p.progr
6371
cache_mc = integrator.p.mcsolve_params.cache_mc
64-
if !is_empty_e_ops
65-
copyto!(cache_mc, integrator.u)
66-
normalize!(cache_mc)
67-
ψ = cache_mc
68-
_expect = op -> dot(ψ, op, ψ)
69-
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
70-
end
71-
next!(progr)
72+
73+
copyto!(cache_mc, integrator.u)
74+
normalize!(cache_mc)
75+
ψ = cache_mc
76+
_expect = op -> dot(ψ, op, ψ)
77+
@. expvals[:, iter[]] = _expect(e_ops)
78+
iter[] += 1
79+
7280
u_modified!(integrator, false)
7381
return nothing
7482
end
@@ -97,8 +105,7 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
97105
merge(kwargs, (callback = cb1,))
98106
return kwargs2
99107
else
100-
is_empty_e_ops = isempty(e_ops)
101-
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), is_empty_e_ops)
108+
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1))
102109
cb2 = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
103110
kwargs2 =
104111
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
@@ -171,6 +178,36 @@ function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
171178
return cb.affect!.e_ops
172179
end
173180

181+
#=
182+
_mcsolve_callbacks_new_iter(prob, tlist)
183+
184+
Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1.
185+
=#
186+
function _mcsolve_callbacks_new_iter(prob, tlist)
187+
cb = prob.kwargs[:callback]
188+
return _mcsolve_callbacks_new_iter(cb, tlist)
189+
end
190+
function _mcsolve_callbacks_new_iter(cb::CallbackSet, tlist)
191+
cb_continuous = cb.continuous_callbacks
192+
cb_discrete = cb.discrete_callbacks
193+
194+
if length(cb_continuous) > 0
195+
idx = 1
196+
e_ops = cb_discrete[idx].affect!.e_ops
197+
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1))
198+
cb_save = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
199+
return CallbackSet(cb_continuous..., cb_save, cb_discrete[2:end]...)
200+
else
201+
idx = 2
202+
e_ops = cb_discrete[idx].affect!.e_ops
203+
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1))
204+
cb_save = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
205+
return CallbackSet(cb_continuous..., cb_discrete[1], cb_save, cb_discrete[3:end]...)
206+
end
207+
end
208+
_mcsolve_callbacks_new_iter(cb::ContinuousCallback, tlist) = cb
209+
_mcsolve_callbacks_new_iter(cb::DiscreteCallback, tlist) = cb
210+
174211
## Temporary function to avoid errors. Waiting for the PR In DiffEqCallbacks.jl to be merged.
175212

176213
import SciMLBase: INITIALIZE_DEFAULT, add_tstop!

src/time_evolution/mcsolve.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
export mcsolveProblem, mcsolveEnsembleProblem, mcsolve
22
export ContinuousLindbladJumpCallback, DiscreteLindbladJumpCallback
33

4-
function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds)
4+
function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds, tlist)
55
params = prob.p
66

77
seed = seeds[i]
88
traj_rng = typeof(global_rng)()
99
seed!(traj_rng, seed)
1010

1111
expvals = similar(params.expvals)
12-
progr = ProgressBar(size(expvals, 2), enable = false)
1312

1413
T = eltype(expvals)
1514

@@ -24,16 +23,17 @@ function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds)
2423
jump_times_which_idx = T[1],
2524
)
2625

27-
p = TimeEvolutionParameters(params.params, expvals, progr, mcsolve_params)
26+
p = TimeEvolutionParameters(params.params, expvals, mcsolve_params)
2827

2928
f = deepcopy(prob.f.f)
29+
cb = _mcsolve_callbacks_new_iter(prob, tlist)
3030

31-
return remake(prob, f = f, p = p)
31+
return remake(prob, f = f, p = p, callback = cb)
3232
end
3333

34-
function _mcsolve_dispatch_prob_func(rng, ntraj)
34+
function _mcsolve_dispatch_prob_func(rng, ntraj, tlist)
3535
seeds = map(i -> rand(rng, UInt64), 1:ntraj)
36-
return (prob, i, repeat) -> _mcsolve_prob_func(prob, i, repeat, rng, seeds)
36+
return (prob, i, repeat) -> _mcsolve_prob_func(prob, i, repeat, rng, seeds, tlist)
3737
end
3838

3939
# Standard output function
@@ -229,7 +229,7 @@ function mcsolveProblem(
229229
jump_which = jump_which,
230230
jump_times_which_idx = jump_times_which_idx,
231231
)
232-
p = TimeEvolutionParameters(params, expvals, progr, mcsolve_params)
232+
p = TimeEvolutionParameters(params, expvals, mcsolve_params)
233233

234234
return sesolveProblem(H_eff_evo, ψ0, tlist; params = p, kwargs3...)
235235
end
@@ -330,7 +330,7 @@ function mcsolveEnsembleProblem(
330330
output_func::Union{Tuple,Nothing} = nothing,
331331
kwargs...,
332332
) where {DT1,DT2,TJC<:LindbladJumpCallbackType}
333-
_prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func(rng, ntraj) : prob_func
333+
_prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func(rng, ntraj, tlist) : prob_func
334334
_output_func =
335335
output_func isa Nothing ? _mcsolve_dispatch_output_func(ensemble_method, progress_bar, ntraj) : output_func
336336

src/time_evolution/sesolve.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,11 @@ function _merge_sesolve_kwargs_with_callback(kwargs, cb)
99
end
1010

1111
# Multiple dispatch depending on the progress_bar and e_ops types
12-
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, tlist, kwargs)
13-
cb = _generate_sesolve_callback(e_ops, tlist)
14-
return _merge_sesolve_kwargs_with_callback(kwargs, cb)
15-
end
16-
17-
function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, tlist, kwargs)
18-
if e_ops isa Nothing
19-
return kwargs
20-
end
21-
cb = _generate_sesolve_callback(e_ops, tlist)
12+
function _generate_sesolve_kwargs(e_ops, progress_bar, tlist, kwargs)
13+
cb = _generate_sesolve_callback(e_ops, tlist, progress_bar)
2214
return _merge_sesolve_kwargs_with_callback(kwargs, cb)
2315
end
16+
_generate_sesolve_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs) = kwargs
2417

2518
_sesolve_make_U_QobjEvo(H::QuantumObjectEvolution{<:MatrixOperator}) =
2619
QobjEvo(MatrixOperator(-1im * H.data.A), dims = H.dims, type = Operator)
@@ -88,8 +81,6 @@ function sesolveProblem(
8881
ψ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type
8982
U = H_evo.data
9083

91-
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))
92-
9384
if e_ops isa Nothing
9485
expvals = Array{ComplexF64}(undef, 0, length(tlist))
9586
is_empty_e_ops = true
@@ -99,14 +90,14 @@ function sesolveProblem(
9990
end
10091

10192
if params isa TimeEvolutionParameters
102-
(!getVal(progress_bar) && (e_ops isa Nothing)) || throw(
93+
(e_ops isa Nothing) || throw(
10394
ArgumentError(
104-
"The parameter `params` cannot be a TimeEvolutionParameters object when `e_ops` is not Nothing and `progress_bar` is true.",
95+
"The parameter `params` cannot be a TimeEvolutionParameters object when `e_ops` is not Nothing",
10596
),
10697
)
10798
end
10899

109-
p = params isa TimeEvolutionParameters ? params : TimeEvolutionParameters(params, expvals, progr)
100+
p = params isa TimeEvolutionParameters ? params : TimeEvolutionParameters(params, expvals)
110101

111102
saveat = is_empty_e_ops ? tlist : [tlist[end]]
112103
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)

src/time_evolution/time_evo_parameters.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
# This function should be implemented after Julia v1.12
2-
Base.@constprop :aggressive function _delete_field(a::NamedTuple{an}, field::Symbol) where {an}
3-
names = Base.diff_names(an, (field,))
4-
return NamedTuple{names}(a)
5-
end
6-
72
@doc raw"""
83
struct TimeEvolutionParameters
94
105
A Julia constructor for handling the parameters of the time evolution of quantum systems.
116
"""
12-
struct TimeEvolutionParameters{ParT,TE<:AbstractMatrix,PT<:ProgressBar,MCST}
7+
struct TimeEvolutionParameters{ParT,TE<:AbstractMatrix,MCST}
138
params::ParT
149
expvals::TE
15-
progr::PT
1610
mcsolve_params::MCST
1711
end
1812

19-
TimeEvolutionParameters(params, expvals, progr) = TimeEvolutionParameters(params, expvals, progr, nothing)
13+
TimeEvolutionParameters(params, expvals) = TimeEvolutionParameters(params, expvals, nothing)
2014

2115
#=
2216
By defining a custom `getproperty` method for the `TimeEvolutionParameters` struct, we can access the fields of `params` directly.
@@ -39,7 +33,7 @@ Base.getindex(obj::TimeEvolutionParameters, i::Int) = getindex(obj.params, i)
3933
Base.length(obj::TimeEvolutionParameters) = length(obj.params)
4034

4135
# function Base.merge(a::TimeEvolutionParameters, b::NamedTuple)
42-
# return TimeEvolutionParameters(merge(a.params, b), a.expvals, a.progr, a.mcsolve_params)
36+
# return TimeEvolutionParameters(merge(a.params, b), a.expvals, a.mcsolve_params)
4337
# end
4438

4539
########## Mark the struct as a SciMLStructure ##########
@@ -75,12 +69,12 @@ end
7569
function replace(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:NamedTuple}
7670
@assert length(newbuffer) == length(p.params)
7771
new_params = NamedTuple{keys(p.params)}(Tuple(newbuffer))
78-
return TimeEvolutionParameters(new_params, p.expvals, p.progr, p.mcsolve_params)
72+
return TimeEvolutionParameters(new_params, p.expvals, p.mcsolve_params)
7973
end
8074

8175
function replace(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:AbstractVector}
8276
@assert length(newbuffer) == length(p.params)
83-
return TimeEvolutionParameters(newbuffer, p.expvals, p.progr, p.mcsolve_params)
77+
return TimeEvolutionParameters(newbuffer, p.expvals, p.mcsolve_params)
8478
end
8579

8680
function replace!(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:AbstractVector}

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ function _dsf_mcsolve_prob_func(prob, i, repeat)
609609
),
610610
)
611611

612-
p = TimeEvolutionParameters(prm, expvals, progr, mcsolve_params)
612+
p = TimeEvolutionParameters(prm, expvals, mcsolve_params)
613613

614614
f = deepcopy(prob.f.f)
615615

0 commit comments

Comments
 (0)