Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)

- Support for single `AbstractQuantumObject` in `sc_ops` for faster specific method in `ssesolve` and `smesolve`. ([#408])
- Change save callbacks from `PresetTimeCallback` to `FunctionCallingCallback`. ([#410])
- Align `eigenstates` and `eigenenergies` to QuTiP. ([#411])

## [v0.27.0]
Expand Down Expand Up @@ -144,4 +145,5 @@ Release date: 2024-11-13
[#404]: https://github.com/qutip/QuantumToolbox.jl/issues/404
[#405]: https://github.com/qutip/QuantumToolbox.jl/issues/405
[#408]: https://github.com/qutip/QuantumToolbox.jl/issues/408
[#410]: https://github.com/qutip/QuantumToolbox.jl/issues/410
[#411]: https://github.com/qutip/QuantumToolbox.jl/issues/411
2 changes: 1 addition & 1 deletion src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import SciMLOperators:
concretize
import LinearSolve: LinearProblem, SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES
import DiffEqBase: get_tstops
import DiffEqCallbacks: PeriodicCallback, PresetTimeCallback, TerminateSteadyState
import DiffEqCallbacks: PeriodicCallback, FunctionCallingCallback, FunctionCallingAffect, TerminateSteadyState
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm
import OrdinaryDiffEqTsit5: Tsit5
import DiffEqNoiseProcess: RealWienerProcess!, RealWienerProcess
Expand Down
18 changes: 9 additions & 9 deletions src/time_evolution/callback_helpers/callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)

expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))

_save_affect! = method(e_ops_data, progr, Ref(1), expvals)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
_save_func = method(e_ops_data, progr, Ref(1), expvals)
return FunctionCallingCallback(_save_func, funcat = tlist)
end

function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
Expand All @@ -69,8 +69,8 @@ function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurem
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing

_save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
_save_func = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
return FunctionCallingCallback(_save_func, funcat = tlist)
end

##
Expand Down Expand Up @@ -98,20 +98,20 @@ function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where
if cb isa Nothing
return nothing
else
return cb.affect!.m_expvals
return cb.affect!.func.m_expvals
end
end

#=
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
This callback can only be a PresetTimeCallback (DiscreteCallback).
This callback can only be a FunctionCallingCallback (DiscreteCallback).
=#
function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc}
cb = _get_save_callback(integrator, method)
if cb isa Nothing
return nothing
else
return cb.affect!.e_ops
return cb.affect!.func.e_ops
end
end

Expand All @@ -121,7 +121,7 @@ function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:Abs
if cb isa Nothing
return nothing
else
return cb.affect!.expvals
return cb.affect!.func.expvals
end
end

Expand Down Expand Up @@ -151,7 +151,7 @@ function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:Abstra
end
end
function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc}
if typeof(cb.affect!) <: AbstractSaveFunc
if typeof(cb.affect!) <: FunctionCallingAffect && typeof(cb.affect!.func) <: AbstractSaveFunc
return cb
end
return nothing
Expand Down
30 changes: 15 additions & 15 deletions src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct SaveFuncMCSolve{TE,IT,TEXPV} <: AbstractSaveFunc
expvals::TEXPV
end

(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)
(f::SaveFuncMCSolve)(u, t, integrator) = _save_func_mcsolve(u, integrator, f.e_ops, f.iter, f.expvals)

_get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2

Expand Down Expand Up @@ -52,10 +52,10 @@ end

##

function _save_func_mcsolve(integrator, e_ops, iter, expvals)
function _save_func_mcsolve(u, integrator, e_ops, iter, expvals)
cache_mc = _mc_get_jump_callback(integrator).affect!.cache_mc

copyto!(cache_mc, integrator.u)
copyto!(cache_mc, u)
normalize!(cache_mc)
ψ = cache_mc
_expect = op -> dot(ψ, op, ψ)
Expand Down Expand Up @@ -114,8 +114,8 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn
else
expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist))

_save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
_save_func = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals)
cb2 = FunctionCallingCallback(_save_func, funcat = tlist)
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) :
merge(kwargs, (callback = CallbackSet(cb1, cb2),))
Expand Down Expand Up @@ -214,11 +214,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)

if _mcsolve_has_continuous_jump(cb)
idx = 1
if cb_discrete[idx].affect! isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.e_ops
expvals = similar(cb_discrete[idx].affect!.expvals)
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
if cb_discrete[idx].affect!.func isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.func.e_ops
expvals = similar(cb_discrete[idx].affect!.func.expvals)
_save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),)
else
cb_save = ()
end
Expand All @@ -229,11 +229,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng)
return CallbackSet((cb_jump, cb_continuous[2:end]...), (cb_save..., cb_discrete[2:end]...))
else
idx = 2
if cb_discrete[idx].affect! isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.e_ops
expvals = similar(cb_discrete[idx].affect!.expvals)
_save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),)
if cb_discrete[idx].affect!.func isa SaveFuncMCSolve
e_ops = cb_discrete[idx].affect!.func.e_ops
expvals = similar(cb_discrete[idx].affect!.func.expvals)
_save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals)
cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),)
else
cb_save = ()
end
Expand Down
10 changes: 5 additions & 5 deletions src/time_evolution/callback_helpers/mesolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
expvals::TEXPV
end

(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr)
(f::SaveFuncMESolve)(u, t, integrator) = _save_func_mesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncMESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr)

Check warning on line 13 in src/time_evolution/callback_helpers/mesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/mesolve_callback_helpers.jl#L13

Added line #L13 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10

##

# When e_ops is a list of operators
function _save_func_mesolve(integrator, e_ops, progr, iter, expvals)
function _save_func_mesolve(u, integrator, e_ops, progr, iter, expvals)
# This is equivalent to tr(op * ρ), when both are matrices.
# The advantage of using this convention is that We don't need
# to reshape u to make it a matrix, but we reshape the e_ops once.

ρ = integrator.u
ρ = u
_expect = op -> dot(op, ρ)
@. expvals[:, iter[]] = _expect(e_ops)
iter[] += 1
Expand All @@ -35,7 +35,7 @@
if cb isa Nothing
return nothing
else
cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators
cb.affect!.func.e_ops .= e_ops # Only works if e_ops is a Vector of operators
return nothing
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
expvals::TEXPV
end

(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve
(f::SaveFuncSESolve)(u, t, integrator) = _save_func_sesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncSESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve

Check warning on line 13 in src/time_evolution/callback_helpers/sesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/sesolve_callback_helpers.jl#L13

Added line #L13 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)

##

# When e_ops is a list of operators
function _save_func_sesolve(integrator, e_ops, progr, iter, expvals)
ψ = integrator.u
function _save_func_sesolve(u, integrator, e_ops, progr, iter, expvals)
ψ = u
_expect = op -> dot(ψ, op, ψ)
@. expvals[:, iter[]] = _expect(e_ops)
iter[] += 1
Expand Down
10 changes: 5 additions & 5 deletions src/time_evolution/callback_helpers/smesolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
m_expvals::TMEXPV
end

(f::SaveFuncSMESolve)(integrator) =
_save_func_smesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSMESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
(f::SaveFuncSMESolve)(u, t, integrator) =
_save_func_smesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSMESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers

Check warning on line 25 in src/time_evolution/callback_helpers/smesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/smesolve_callback_helpers.jl#L25

Added line #L25 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncSMESolve}) = _get_e_ops_data(e_ops, SaveFuncMESolve)
_get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) =
Expand All @@ -31,12 +31,12 @@
##

# When e_ops is a list of operators
function _save_func_smesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
# This is equivalent to tr(op * ρ), when both are matrices.
# The advantage of using this convention is that We don't need
# to reshape u to make it a matrix, but we reshape the e_ops once.

ρ = integrator.u
ρ = u

_expect = op -> dot(op, ρ)

Expand Down
10 changes: 5 additions & 5 deletions src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
m_expvals::TMEXPV
end

(f::SaveFuncSSESolve)(integrator) =
_save_func_ssesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSSESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers
(f::SaveFuncSSESolve)(u, t, integrator) =
_save_func_ssesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals)
(f::SaveFuncSSESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers

Check warning on line 25 in src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl#L25

Added line #L25 was not covered by tests

_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)
_get_m_ops_data(sc_ops, ::Type{SaveFuncSSESolve}) = map(op -> Hermitian(get_data(op) + get_data(op)'), sc_ops)
Expand All @@ -32,8 +32,8 @@
##

# When e_ops is a list of operators
function _save_func_ssesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
ψ = integrator.u
function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals)
ψ = u

_expect = op -> dot(ψ, op, ψ)

Expand Down
Loading