Skip to content

Commit eb05e53

Browse files
Fix dsf_mcsolve
1 parent 18549c1 commit eb05e53

File tree

3 files changed

+140
-46
lines changed

3 files changed

+140
-46
lines changed

src/QuantumToolbox.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ import SciMLBase:
3535
CallbackSet,
3636
ContinuousCallback,
3737
DiscreteCallback,
38-
AbstractSciMLProblem
38+
AbstractSciMLProblem,
39+
AbstractODEIntegrator
3940
import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA1
4041
import SciMLOperators:
4142
SciMLOperators,

src/time_evolution/callback_helpers.jl

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ end
1212
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.is_empty_e_ops)
1313
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func_sesolve(integrator)
1414

15+
##
16+
1517
# When e_ops is Nothing
1618
function _save_func_sesolve(integrator)
1719
next!(integrator.p.progr)
@@ -25,15 +27,15 @@ function _save_func_sesolve(integrator, e_ops, is_empty_e_ops)
2527
progr = integrator.p.progr
2628
if !is_empty_e_ops
2729
ψ = integrator.u
28-
_expect = op -> dot(ψ, get_data(op), ψ)
30+
_expect = op -> dot(ψ, op, ψ)
2931
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
3032
end
3133
return _save_func_sesolve(integrator)
3234
end
3335

3436
function _generate_sesolve_callback(e_ops, tlist)
3537
is_empty_e_ops = e_ops isa Nothing ? true : isempty(e_ops)
36-
_save_affect! = SaveFuncSESolve(e_ops, is_empty_e_ops)
38+
_save_affect! = SaveFuncSESolve(get_data.(e_ops), is_empty_e_ops)
3739
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
3840
end
3941

@@ -46,12 +48,12 @@ end
4648

4749
(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.is_empty_e_ops)
4850

49-
struct LindbladJumpAffect!{T1,T2}
51+
struct LindbladJump{T1,T2}
5052
c_ops::T1
5153
c_ops_herm::T2
5254
end
5355

54-
(f::LindbladJumpAffect!)(integrator) = _lindblad_jump_affect!(integrator, f.c_ops, f.c_ops_herm)
56+
(f::LindbladJump)(integrator) = _lindblad_jump_affect!(integrator, f.c_ops, f.c_ops_herm)
5557

5658
##
5759

@@ -63,7 +65,7 @@ function _save_func_mcsolve(integrator, e_ops, is_empty_e_ops)
6365
copyto!(cache_mc, integrator.u)
6466
normalize!(cache_mc)
6567
ψ = cache_mc
66-
_expect = op -> dot(ψ, get_data(op), ψ)
68+
_expect = op -> dot(ψ, op, ψ)
6769
@. expvals[:, progr.counter[]+1] = _expect(e_ops)
6870
end
6971
next!(progr)
@@ -75,7 +77,7 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
7577
c_ops_data = get_data.(c_ops)
7678
c_ops_herm_data = map(op -> op' * op, c_ops_data)
7779

78-
_affect! = LindbladJumpAffect!(c_ops_data, c_ops_herm_data)
80+
_affect! = LindbladJump(c_ops_data, c_ops_herm_data)
7981

8082
if jump_callback isa DiscreteLindbladJumpCallback
8183
cb1 = DiscreteCallback(_mcsolve_discrete_condition, _affect!, save_positions = (false, false))
@@ -96,8 +98,8 @@ function _generate_mcsolve_kwargs(e_ops, tlist, c_ops, jump_callback, kwargs)
9698
return kwargs2
9799
else
98100
is_empty_e_ops = isempty(e_ops)
99-
_save_affect! = SaveFuncMCSolve(e_ops, is_empty_e_ops)
100-
cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
101+
_save_affect! = SaveFuncMCSolve(get_data.(e_ops), is_empty_e_ops)
102+
cb2 = _PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
101103
kwargs2 =
102104
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
103105
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
@@ -146,3 +148,74 @@ _mcsolve_continuous_condition(u, t, integrator) =
146148

147149
_mcsolve_discrete_condition(u, t, integrator) =
148150
@inbounds real(dot(u, u)) < real(integrator.p.mcsolve_params.random_n[1])
151+
152+
#=
153+
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
154+
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
155+
=#
156+
function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
157+
cb_set = integrator.opts.callback # This is supposed to be a CallbackSet
158+
(cb_set isa CallbackSet) || throw(ArgumentError("The callback must be a CallbackSet."))
159+
cb = isempty(cb_set.continuous_callbacks) ? cb_set.discrete_callback[1] : cb_set.continuous_callbacks[1]
160+
return cb.affect!.c_ops, cb.affect!.c_ops_herm
161+
end
162+
163+
#=
164+
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
165+
This callback can only be a PresetTimeCallback (DiscreteCallback).
166+
=#
167+
function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
168+
cb_set = integrator.opts.callback # This is supposed to be a CallbackSet
169+
(cb_set isa CallbackSet) || throw(ArgumentError("The callback must be a CallbackSet."))
170+
cb = length(cb_set.continuous_callbacks) > 0 ? cb_set.discrete_callbacks[1] : cb_set.discrete_callbacks[2]
171+
return cb.affect!.e_ops
172+
end
173+
174+
## Temporary function to avoid errors. Waiting for the PR In DiffEqCallbacks.jl to be merged.
175+
176+
import SciMLBase: INITIALIZE_DEFAULT, add_tstop!
177+
178+
function _PresetTimeCallback(
179+
tstops,
180+
user_affect!;
181+
initialize = INITIALIZE_DEFAULT,
182+
filter_tstops = true,
183+
sort_inplace = false,
184+
kwargs...,
185+
)
186+
if !(tstops isa AbstractVector) && !(tstops isa Number)
187+
throw(ArgumentError("tstops must either be a number or a Vector. Was $tstops"))
188+
end
189+
190+
tstops = tstops isa Number ? [tstops] : (sort_inplace ? sort!(tstops) : sort(tstops))
191+
192+
condition = let
193+
function (u, t, integrator)
194+
if hasproperty(integrator, :dt)
195+
insorted(t, tstops) && (integrator.t - integrator.dt) != integrator.t
196+
else
197+
insorted(t, tstops)
198+
end
199+
end
200+
end
201+
202+
# Initialization: first call to `f` should be *before* any time steps have been taken:
203+
initialize_preset = function (c, u, t, integrator)
204+
initialize(c, u, t, integrator)
205+
206+
if filter_tstops
207+
tdir = integrator.tdir
208+
tspan = integrator.sol.prob.tspan
209+
_tstops = tstops[@. tdir * tspan[1] < tdir * tstops < tdir * tspan[2]]
210+
else
211+
_tstops = tstops
212+
end
213+
for tstop in _tstops
214+
add_tstop!(integrator, tstop)
215+
end
216+
if t tstops
217+
user_affect!(integrator)
218+
end
219+
end
220+
return DiscreteCallback(condition, user_affect!; initialize = initialize_preset, kwargs...)
221+
end

src/time_evolution/time_evolution_dynamical.jl

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,9 @@ end
489489
# Dynamical Shifted Fock mcsolve
490490

491491
function _DSF_mcsolve_Condition(u, t, integrator)
492-
internal_params = integrator.p
493-
op_list = internal_params.op_list
494-
δα_list = internal_params.δα_list
492+
params = integrator.p
493+
op_list = params.op_list
494+
δα_list = params.δα_list
495495

496496
ψt = u
497497

@@ -508,20 +508,24 @@ function _DSF_mcsolve_Condition(u, t, integrator)
508508
end
509509

510510
function _DSF_mcsolve_Affect!(integrator)
511-
internal_params = integrator.p
512-
op_list = internal_params.op_list
513-
αt_list = internal_params.αt_list
514-
δα_list = internal_params.δα_list
515-
H = internal_params.H_fun
516-
c_ops = internal_params.c_ops_fun
517-
e_ops = internal_params.e_ops_fun
518-
e_ops0 = internal_params.e_ops_mc
519-
c_ops0 = internal_params.c_ops
520-
ψt = internal_params.dsf_cache1
521-
dsf_cache = internal_params.dsf_cache2
522-
expv_cache = internal_params.expv_cache
523-
dsf_params = internal_params.dsf_params
524-
dsf_displace_cache_full = internal_params.dsf_displace_cache_full
511+
params = integrator.p
512+
op_list = params.op_list
513+
αt_list = params.αt_list
514+
δα_list = params.δα_list
515+
H = params.H_fun
516+
c_ops = params.c_ops_fun
517+
e_ops = params.e_ops_fun
518+
ψt = params.dsf_cache1
519+
dsf_cache = params.dsf_cache2
520+
expv_cache = params.expv_cache
521+
dsf_params = params.dsf_params
522+
dsf_displace_cache_full = params.dsf_displace_cache_full
523+
524+
# e_ops0 = params.e_ops
525+
# c_ops0 = params.c_ops
526+
527+
e_ops0 = _mcsolve_get_e_ops(integrator)
528+
c_ops0, c_ops0_herm = _mcsolve_get_c_ops(integrator)
525529

526530
copyto!(ψt, integrator.u)
527531
normalize!(ψt)
@@ -561,42 +565,58 @@ function _DSF_mcsolve_Affect!(integrator)
561565
op_l2 = op_list .+ αt_list
562566
e_ops2 = e_ops(op_l2, dsf_params)
563567
c_ops2 = c_ops(op_l2, dsf_params)
568+
569+
## By copying the data, we are assuming that the variables are Vectors and not Tuple
564570
@. e_ops0 = get_data(e_ops2)
565571
@. c_ops0 = get_data(c_ops2)
566-
H_nh = lmul!(convert(eltype(ψt), 0.5im), mapreduce(op -> op' * op, +, c_ops0))
572+
c_ops0_herm .= map(op -> op' * op, c_ops0)
573+
574+
H_nh = convert(eltype(ψt), 0.5im) * sum(c_ops0_herm)
567575
# By doing this, we are assuming that the system is time-independent and f is a MatrixOperator
568576
copyto!(integrator.f.f.A, lmul!(-1im, H(op_l2, dsf_params).data - H_nh))
569577
return u_modified!(integrator, true)
570578
end
571579

572580
function _dsf_mcsolve_prob_func(prob, i, repeat)
573-
internal_params = prob.p
581+
params = prob.p
582+
583+
expvals = similar(params.expvals)
584+
progr = ProgressBar(size(expvals, 2), enable = false)
585+
586+
T = eltype(expvals)
587+
588+
mcsolve_params = merge(
589+
params.mcsolve_params,
590+
(
591+
random_n = T[rand()],
592+
cache_mc = similar(params.mcsolve_params.cache_mc),
593+
weights_mc = similar(params.mcsolve_params.weights_mc),
594+
cumsum_weights_mc = similar(params.mcsolve_params.weights_mc),
595+
jump_times = similar(params.mcsolve_params.jump_times),
596+
jump_which = similar(params.mcsolve_params.jump_which),
597+
jump_times_which_idx = T[1],
598+
),
599+
)
574600

575601
prm = merge(
576-
internal_params,
602+
params.params,
577603
(
578-
e_ops_mc = deepcopy(internal_params.e_ops_mc),
579-
c_ops = deepcopy(internal_params.c_ops),
580-
expvals = similar(internal_params.expvals),
581-
cache_mc = similar(internal_params.cache_mc),
582-
weights_mc = similar(internal_params.weights_mc),
583-
cumsum_weights_mc = similar(internal_params.weights_mc),
584-
random_n = Ref(rand()),
585-
progr_mc = ProgressBar(size(internal_params.expvals, 2), enable = false),
586-
jump_times_which_idx = Ref(1),
587-
jump_times = similar(internal_params.jump_times),
588-
jump_which = similar(internal_params.jump_which),
589-
αt_list = copy(internal_params.αt_list),
590-
dsf_cache1 = similar(internal_params.dsf_cache1),
591-
dsf_cache2 = similar(internal_params.dsf_cache2),
592-
expv_cache = copy(internal_params.expv_cache),
593-
dsf_displace_cache_full = deepcopy(internal_params.dsf_displace_cache_full), # This brutally copies also the MatrixOperators, and it is inefficient.
604+
αt_list = copy(params.params.αt_list),
605+
dsf_cache1 = similar(params.params.dsf_cache1),
606+
dsf_cache2 = similar(params.params.dsf_cache2),
607+
expv_cache = copy(params.params.expv_cache),
608+
dsf_displace_cache_full = deepcopy(params.params.dsf_displace_cache_full), # This brutally copies also the MatrixOperators, and it is inefficient.
594609
),
595610
)
596611

612+
p = TimeEvolutionParameters(prm, expvals, progr, mcsolve_params)
613+
597614
f = deepcopy(prob.f.f)
598615

599-
return remake(prob, f = f, p = prm)
616+
# We need to deepcopy the callbacks because they contain the c_ops and e_ops, which are modified in the affect function
617+
cb = deepcopy(prob.kwargs[:callback])
618+
619+
return remake(prob, f = f, p = p, callback = cb)
600620
end
601621

602622
function dsf_mcsolveEnsembleProblem(

0 commit comments

Comments
 (0)