Skip to content

Commit aab15d9

Browse files
[no ci] Make dWdt in-place
1 parent 9d7f6b3 commit aab15d9

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

src/time_evolution/callback_helpers/callback_helpers.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurem
7474
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
7575
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing
7676

77-
_save_func = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals, tlist)
77+
_save_func_cache = Array{Float64}(undef, length(sc_ops))
78+
_save_func =
79+
method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals, tlist, _save_func_cache)
7880
return FunctionCallingCallback(_save_func, funcat = tlist)
7981
end
8082

@@ -169,11 +171,11 @@ _get_save_callback_idx(cb, method) = 1
169171

170172
# TODO: Add some cache mechanism to avoid memory allocations
171173
# TODO: To improve. See https://github.com/SciML/DiffEqNoiseProcess.jl/issues/214
172-
function _homodyne_dWdt(integrator, tlist, iter)
174+
function _homodyne_dWdt!(dWdt_cache, integrator, tlist, iter)
173175
idx = findfirst(>=(tlist[iter[]-1]), integrator.W.t)
174176

175177
# We are assuming that the last element is tlist[iter[]]
176-
@inbounds _dWdt = (integrator.W.u[end] .- integrator.W.u[idx]) ./ (integrator.W.t[end] - integrator.W.t[idx])
178+
@inbounds dWdt_cache .= (integrator.W.u[end] .- integrator.W.u[idx]) ./ (integrator.W.t[end] - integrator.W.t[idx])
177179

178-
return _dWdt
180+
return nothing
179181
end

src/time_evolution/callback_helpers/smesolve_callback_helpers.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct SaveFuncSMESolve{
1111
TEXPV<:Union{Nothing,AbstractMatrix},
1212
TMEXPV<:Union{Nothing,AbstractMatrix},
1313
TLT<:AbstractVector,
14+
CT<:AbstractVector,
1415
} <: AbstractSaveFunc
1516
store_measurement::Val{SM}
1617
e_ops::TE
@@ -20,10 +21,11 @@ struct SaveFuncSMESolve{
2021
expvals::TEXPV
2122
m_expvals::TMEXPV
2223
tlist::TLT
24+
dWdt_cache::CT
2325
end
2426

2527
(f::SaveFuncSMESolve)(u, t, integrator) =
26-
_save_func_smesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals, f.tlist)
28+
_save_func_smesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals, f.tlist, f.dWdt_cache)
2729
(f::SaveFuncSMESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers
2830

2931
_get_e_ops_data(e_ops, ::Type{SaveFuncSMESolve}) = _get_e_ops_data(e_ops, SaveFuncMESolve)
@@ -33,7 +35,7 @@ _get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) =
3335
##
3436

3537
# When e_ops is a list of operators
36-
function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals, tlist)
38+
function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals, tlist, dWdt_cache)
3739
# This is equivalent to tr(op * ρ), when both are matrices.
3840
# The advantage of using this convention is that We don't need
3941
# to reshape u to make it a matrix, but we reshape the e_ops once.
@@ -47,8 +49,8 @@ function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals,
4749
end
4850

4951
if !isnothing(m_expvals) && iter[] > 1
50-
_dWdt = _homodyne_dWdt(integrator, tlist, iter)
51-
@. m_expvals[:, iter[]-1] = real(_expect(m_ops)) + _dWdt
52+
_homodyne_dWdt!(dWdt_cache, integrator, tlist, iter)
53+
@. m_expvals[:, iter[]-1] = real(_expect(m_ops)) + dWdt_cache
5254
end
5355

5456
iter[] += 1

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct SaveFuncSSESolve{
1111
TEXPV<:Union{Nothing,AbstractMatrix},
1212
TMEXPV<:Union{Nothing,AbstractMatrix},
1313
TLT<:AbstractVector,
14+
CT<:AbstractVector,
1415
} <: AbstractSaveFunc
1516
store_measurement::Val{SM}
1617
e_ops::TE
@@ -20,10 +21,11 @@ struct SaveFuncSSESolve{
2021
expvals::TEXPV
2122
m_expvals::TMEXPV
2223
tlist::TLT
24+
dWdt_cache::CT
2325
end
2426

2527
(f::SaveFuncSSESolve)(u, t, integrator) =
26-
_save_func_ssesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals, f.tlist)
28+
_save_func_ssesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals, f.tlist, f.dWdt_cache)
2729
(f::SaveFuncSSESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers
2830

2931
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
@@ -34,7 +36,7 @@ _get_save_callback_idx(cb, ::Type{SaveFuncSSESolve}) = 2 # The first one is the
3436
##
3537

3638
# When e_ops is a list of operators
37-
function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals, tlist)
39+
function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals, tlist, dWdt_cache)
3840
ψ = u
3941

4042
_expect = op -> dot(ψ, op, ψ)
@@ -44,8 +46,8 @@ function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals,
4446
end
4547

4648
if !isnothing(m_expvals) && iter[] > 1
47-
_dWdt = _homodyne_dWdt(integrator, tlist, iter)
48-
@. m_expvals[:, iter[]-1] = real(_expect(m_ops)) + _dWdt
49+
_homodyne_dWdt!(dWdt_cache, integrator, tlist, iter)
50+
@. m_expvals[:, iter[]-1] = real(_expect(m_ops)) + dWdt_cache
4951
end
5052

5153
iter[] += 1

0 commit comments

Comments
 (0)