Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)

- Support for single `AbstractQuantumObject` in `sc_ops` for faster specific method in `ssesolve` and `smesolve`. ([#408])
- Change save callbacks from `PresetTimeCallback` to `FunctionCallingCallback`. ([#410])

## [v0.27.0]
Release date: 2025-02-14
Expand Down Expand Up @@ -143,3 +144,4 @@ Release date: 2024-11-13
[#404]: https://github.com/qutip/QuantumToolbox.jl/issues/404
[#405]: https://github.com/qutip/QuantumToolbox.jl/issues/405
[#408]: https://github.com/qutip/QuantumToolbox.jl/issues/408
[#410]: https://github.com/qutip/QuantumToolbox.jl/issues/410
2 changes: 1 addition & 1 deletion src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import SciMLOperators:
concretize
import LinearSolve: LinearProblem, SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES
import DiffEqBase: get_tstops
import DiffEqCallbacks: PeriodicCallback, PresetTimeCallback, TerminateSteadyState
import DiffEqCallbacks: PeriodicCallback, FunctionCallingCallback, FunctionCallingAffect, TerminateSteadyState
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm
import OrdinaryDiffEqTsit5: Tsit5
import DiffEqNoiseProcess: RealWienerProcess!, RealWienerProcess
Expand Down
18 changes: 9 additions & 9 deletions src/time_evolution/callback_helpers/callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)

expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))

_save_affect! = method(e_ops_data, progr, Ref(1), expvals)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
_save_func = method(e_ops_data, progr, Ref(1), expvals)
return FunctionCallingCallback(_save_func, funcat = tlist)
end

function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
Expand All @@ -69,8 +69,8 @@ function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurem
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing

_save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
_save_func = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
return FunctionCallingCallback(_save_func, funcat = tlist)
end

##
Expand Down Expand Up @@ -98,20 +98,20 @@ function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where
if cb isa Nothing
return nothing
else
return cb.affect!.m_expvals
return cb.affect!.func.m_expvals
end
end

#=
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
This callback can only be a PresetTimeCallback (DiscreteCallback).
This callback can only be a FunctionCallingCallback (DiscreteCallback).
=#
function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc}
cb = _get_save_callback(integrator, method)
if cb isa Nothing
return nothing
else
return cb.affect!.e_ops
return cb.affect!.func.e_ops
end
end

Expand All @@ -121,7 +121,7 @@ function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:Abs
if cb isa Nothing
return nothing
else
return cb.affect!.expvals
return cb.affect!.func.expvals
end
end

Expand Down Expand Up @@ -151,7 +151,7 @@ function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:Abstra
end
end
function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc}
if typeof(cb.affect!) <: AbstractSaveFunc
if typeof(cb.affect!) <: FunctionCallingAffect && typeof(cb.affect!.func) <: AbstractSaveFunc
return cb
end
return nothing
Expand Down
30 changes: 15 additions & 15 deletions src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct SaveFuncMCSolve{TE,IT,TEXPV} <: AbstractSaveFunc
expvals::TEXPV
end

(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
(f::SaveFuncMCSolve)(u, t, integrator) = _save_func_mcsolve(u, integrator, f.e_ops, f.iter, f.expvals)

_get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2

Expand Down Expand Up @@ -52,10 +52,10 @@ end

##

function _save_func_mcsolve(integrator, e_ops, iter, expvals)
function _save_func_mcsolve(u, integrator, e_ops, iter, expvals)
cache_mc = _mc_get_jump_callback(integrator).affect!.cache_mc

copyto!(cache_mc, integrator.u)
copyto!(cache_mc, u)
normalize!(cache_mc)
ψ = cache_mc
_expect = op -> dot(ψ, op, ψ)
Expand Down Expand Up @@ -114,8 +114,8 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
else
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))

_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
_save_func = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
cb2 = FunctionCallingCallback(_save_func, funcat = tlist)
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
Expand Down Expand Up @@ -214,11 +214,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)

if _mcsolve_has_continuous_jump(cb)
idx = 1
if cb_discrete[idx].affect! isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.e_ops
expvals = similar(cb_discrete[idx].affect!.expvals)
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
if cb_discrete[idx].affect!.func isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.func.e_ops
expvals = similar(cb_discrete[idx].affect!.func.expvals)
_save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),)
else
cb_save = ()
end
Expand All @@ -229,11 +229,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)
return CallbackSet((cb_jump, cb_continuous[2:end]...), (cb_save..., cb_discrete[2:end]...))
else
idx = 2
if cb_discrete[idx].affect! isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.e_ops
expvals = similar(cb_discrete[idx].affect!.expvals)
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
if cb_discrete[idx].affect!.func isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.func.e_ops
expvals = similar(cb_discrete[idx].affect!.func.expvals)
_save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),)
else
cb_save = ()
end
Expand Down
10 changes: 5 additions & 5 deletions src/time_evolution/callback_helpers/mesolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
expvals::TEXPV
end

(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr)
(f::SaveFuncMESolve)(u, t, integrator) = _save_func_mesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncMESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr)

Check warning on line 13 in src/time_evolution/callback_helpers/mesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/mesolve_callback_helpers.jl#L13

Added line #L13 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10

##

# When e_ops is a list of operators
function _save_func_mesolve(integrator, e_ops, progr, iter, expvals)
function _save_func_mesolve(u, integrator, e_ops, progr, iter, expvals)
# This is equivalent to tr(op * ρ), when both are matrices.
# The advantage of using this convention is that We don't need
# to reshape u to make it a matrix, but we reshape the e_ops once.

ρ = integrator.u
ρ = u
_expect = op -> dot(op, ρ)
@. expvals[:, iter[]] = _expect(e_ops)
iter[] += 1
Expand All @@ -35,7 +35,7 @@
if cb isa Nothing
return nothing
else
cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators
cb.affect!.func.e_ops .= e_ops # Only works if e_ops is a Vector of operators
return nothing
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
expvals::TEXPV
end

(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
(f::SaveFuncSESolve)(u, t, integrator) = _save_func_sesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncSESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve

Check warning on line 13 in src/time_evolution/callback_helpers/sesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/sesolve_callback_helpers.jl#L13

Added line #L13 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)

##

# When e_ops is a list of operators
function _save_func_sesolve(integrator, e_ops, progr, iter, expvals)
ψ = integrator.u
function _save_func_sesolve(u, integrator, e_ops, progr, iter, expvals)
ψ = u
_expect = op -> dot(ψ, op, ψ)
@. expvals[:, iter[]] = _expect(e_ops)
iter[] += 1
Expand Down
10 changes: 5 additions & 5 deletions src/time_evolution/callback_helpers/smesolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
m_expvals::TMEXPV
end

(f::SaveFuncSMESolve)(integrator) =
_save_func_smesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSMESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
(f::SaveFuncSMESolve)(u, t, integrator) =
_save_func_smesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSMESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers

Check warning on line 25 in src/time_evolution/callback_helpers/smesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/smesolve_callback_helpers.jl#L25

Added line #L25 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncSMESolve}) = _get_e_ops_data(e_ops, SaveFuncMESolve)
_get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) =
Expand All @@ -31,12 +31,12 @@
##

# When e_ops is a list of operators
function _save_func_smesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
# This is equivalent to tr(op * ρ), when both are matrices.
# The advantage of using this convention is that We don't need
# to reshape u to make it a matrix, but we reshape the e_ops once.

ρ = integrator.u
ρ = u

_expect = op -> dot(op, ρ)

Expand Down
10 changes: 5 additions & 5 deletions src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
m_expvals::TMEXPV
end

(f::SaveFuncSSESolve)(integrator) =
_save_func_ssesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSSESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
(f::SaveFuncSSESolve)(u, t, integrator) =
_save_func_ssesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSSESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers

Check warning on line 25 in src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl#L25

Added line #L25 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
_get_m_ops_data(sc_ops, ::Type{SaveFuncSSESolve}) = map(op -> Hermitian(get_data(op) + get_data(op)'), sc_ops)
Expand All @@ -32,8 +32,8 @@
##

# When e_ops is a list of operators
function _save_func_ssesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
ψ = integrator.u
function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
ψ = u

_expect = op -> dot(ψ, op, ψ)

Expand Down
Loading