diff --git a/CHANGELOG.md b/CHANGELOG.md index 253449b48..b2dcd5593 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main) - Support for single `AbstractQuantumObject` in `sc_ops` for faster specific method in `ssesolve` and `smesolve`. ([#408]) +- Change save callbacks from `PresetTimeCallback` to `FunctionCallingCallback`. ([#410]) - Align `eigenstates` and `eigenenergies` to QuTiP. ([#411]) ## [v0.27.0] @@ -144,4 +145,5 @@ Release date: 2024-11-13 [#404]: https://github.com/qutip/QuantumToolbox.jl/issues/404 [#405]: https://github.com/qutip/QuantumToolbox.jl/issues/405 [#408]: https://github.com/qutip/QuantumToolbox.jl/issues/408 +[#410]: https://github.com/qutip/QuantumToolbox.jl/issues/410 [#411]: https://github.com/qutip/QuantumToolbox.jl/issues/411 diff --git a/src/QuantumToolbox.jl b/src/QuantumToolbox.jl index e4835221b..d06979dda 100644 --- a/src/QuantumToolbox.jl +++ b/src/QuantumToolbox.jl @@ -53,7 +53,7 @@ import SciMLOperators: concretize import LinearSolve: LinearProblem, SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES import DiffEqBase: get_tstops -import DiffEqCallbacks: PeriodicCallback, PresetTimeCallback, TerminateSteadyState +import DiffEqCallbacks: PeriodicCallback, FunctionCallingCallback, FunctionCallingAffect, TerminateSteadyState import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm import OrdinaryDiffEqTsit5: Tsit5 import DiffEqNoiseProcess: RealWienerProcess!, RealWienerProcess diff --git a/src/time_evolution/callback_helpers/callback_helpers.jl b/src/time_evolution/callback_helpers/callback_helpers.jl index 4170635a8..e84085b47 100644 --- a/src/time_evolution/callback_helpers/callback_helpers.jl +++ b/src/time_evolution/callback_helpers/callback_helpers.jl @@ -56,8 +56,8 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method) expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist)) - _save_affect! = method(e_ops_data, progr, Ref(1), expvals) - return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)) + _save_func = method(e_ops_data, progr, Ref(1), expvals) + return FunctionCallingCallback(_save_func, funcat = tlist) end function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method) @@ -69,8 +69,8 @@ function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurem expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist)) m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing - _save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals) - return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)) + _save_func = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals) + return FunctionCallingCallback(_save_func, funcat = tlist) end ## @@ -98,20 +98,20 @@ function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where if cb isa Nothing return nothing else - return cb.affect!.m_expvals + return cb.affect!.func.m_expvals end end #= With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator. - This callback can only be a PresetTimeCallback (DiscreteCallback). + This callback can only be a FunctionCallingCallback (DiscreteCallback). =# function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc} cb = _get_save_callback(integrator, method) if cb isa Nothing return nothing else - return cb.affect!.e_ops + return cb.affect!.func.e_ops end end @@ -121,7 +121,7 @@ function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:Abs if cb isa Nothing return nothing else - return cb.affect!.expvals + return cb.affect!.func.expvals end end @@ -151,7 +151,7 @@ function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:Abstra end end function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc} - if typeof(cb.affect!) <: AbstractSaveFunc + if typeof(cb.affect!) <: FunctionCallingAffect && typeof(cb.affect!.func) <: AbstractSaveFunc return cb end return nothing diff --git a/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl b/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl index b0095e705..d96b50202 100644 --- a/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl +++ b/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl @@ -8,7 +8,7 @@ struct SaveFuncMCSolve{TE,IT,TEXPV} <: AbstractSaveFunc expvals::TEXPV end -(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals) +(f::SaveFuncMCSolve)(u, t, integrator) = _save_func_mcsolve(u, integrator, f.e_ops, f.iter, f.expvals) _get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2 @@ -52,10 +52,10 @@ end ## -function _save_func_mcsolve(integrator, e_ops, iter, expvals) +function _save_func_mcsolve(u, integrator, e_ops, iter, expvals) cache_mc = _mc_get_jump_callback(integrator).affect!.cache_mc - copyto!(cache_mc, integrator.u) + copyto!(cache_mc, u) normalize!(cache_mc) ψ = cache_mc _expect = op -> dot(ψ, op, ψ) @@ -114,8 +114,8 @@ function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rn else expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) - _save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals) - cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)) + _save_func = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals) + cb2 = FunctionCallingCallback(_save_func, funcat = tlist) kwargs2 = haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) : merge(kwargs, (callback = CallbackSet(cb1, cb2),)) @@ -214,11 +214,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng) if _mcsolve_has_continuous_jump(cb) idx = 1 - if cb_discrete[idx].affect! isa SaveFuncMCSolve - e_ops = cb_discrete[idx].affect!.e_ops - expvals = similar(cb_discrete[idx].affect!.expvals) - _save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals) - cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),) + if cb_discrete[idx].affect!.func isa SaveFuncMCSolve + e_ops = cb_discrete[idx].affect!.func.e_ops + expvals = similar(cb_discrete[idx].affect!.func.expvals) + _save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals) + cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),) else cb_save = () end @@ -229,11 +229,11 @@ function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng) return CallbackSet((cb_jump, cb_continuous[2:end]...), (cb_save..., cb_discrete[2:end]...)) else idx = 2 - if cb_discrete[idx].affect! isa SaveFuncMCSolve - e_ops = cb_discrete[idx].affect!.e_ops - expvals = similar(cb_discrete[idx].affect!.expvals) - _save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals) - cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),) + if cb_discrete[idx].affect!.func isa SaveFuncMCSolve + e_ops = cb_discrete[idx].affect!.func.e_ops + expvals = similar(cb_discrete[idx].affect!.func.expvals) + _save_func = SaveFuncMCSolve(e_ops, Ref(1), expvals) + cb_save = (FunctionCallingCallback(_save_func, funcat = tlist),) else cb_save = () end diff --git a/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl b/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl index ee253d765..a24229d29 100644 --- a/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl +++ b/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl @@ -9,20 +9,20 @@ struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing expvals::TEXPV end -(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals) -(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr) +(f::SaveFuncMESolve)(u, t, integrator) = _save_func_mesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals) +(f::SaveFuncMESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr) _get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10 ## # When e_ops is a list of operators -function _save_func_mesolve(integrator, e_ops, progr, iter, expvals) +function _save_func_mesolve(u, integrator, e_ops, progr, iter, expvals) # This is equivalent to tr(op * ρ), when both are matrices. # The advantage of using this convention is that We don't need # to reshape u to make it a matrix, but we reshape the e_ops once. - ρ = integrator.u + ρ = u _expect = op -> dot(op, ρ) @. expvals[:, iter[]] = _expect(e_ops) iter[] += 1 @@ -35,7 +35,7 @@ function _mesolve_callbacks_new_e_ops!(integrator::AbstractODEIntegrator, e_ops) if cb isa Nothing return nothing else - cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators + cb.affect!.func.e_ops .= e_ops # Only works if e_ops is a Vector of operators return nothing end end diff --git a/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl b/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl index e5205ffc6..2bbff8bf0 100644 --- a/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl +++ b/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl @@ -9,16 +9,16 @@ struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing expvals::TEXPV end -(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals) -(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve +(f::SaveFuncSESolve)(u, t, integrator) = _save_func_sesolve(u, integrator, f.e_ops, f.progr, f.iter, f.expvals) +(f::SaveFuncSESolve{Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve _get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops) ## # When e_ops is a list of operators -function _save_func_sesolve(integrator, e_ops, progr, iter, expvals) - ψ = integrator.u +function _save_func_sesolve(u, integrator, e_ops, progr, iter, expvals) + ψ = u _expect = op -> dot(ψ, op, ψ) @. expvals[:, iter[]] = _expect(e_ops) iter[] += 1 diff --git a/src/time_evolution/callback_helpers/smesolve_callback_helpers.jl b/src/time_evolution/callback_helpers/smesolve_callback_helpers.jl index ab14bf81b..6fb39c14b 100644 --- a/src/time_evolution/callback_helpers/smesolve_callback_helpers.jl +++ b/src/time_evolution/callback_helpers/smesolve_callback_helpers.jl @@ -20,9 +20,9 @@ struct SaveFuncSMESolve{ m_expvals::TMEXPV end -(f::SaveFuncSMESolve)(integrator) = - _save_func_smesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals) -(f::SaveFuncSMESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers +(f::SaveFuncSMESolve)(u, t, integrator) = + _save_func_smesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals) +(f::SaveFuncSMESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers _get_e_ops_data(e_ops, ::Type{SaveFuncSMESolve}) = _get_e_ops_data(e_ops, SaveFuncMESolve) _get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) = @@ -31,12 +31,12 @@ _get_m_ops_data(sc_ops, ::Type{SaveFuncSMESolve}) = ## # When e_ops is a list of operators -function _save_func_smesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals) +function _save_func_smesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals) # This is equivalent to tr(op * ρ), when both are matrices. # The advantage of using this convention is that We don't need # to reshape u to make it a matrix, but we reshape the e_ops once. - ρ = integrator.u + ρ = u _expect = op -> dot(op, ρ) diff --git a/src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl b/src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl index bd20d2d2c..c84cd7e5e 100644 --- a/src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl +++ b/src/time_evolution/callback_helpers/ssesolve_callback_helpers.jl @@ -20,9 +20,9 @@ struct SaveFuncSSESolve{ m_expvals::TMEXPV end -(f::SaveFuncSSESolve)(integrator) = - _save_func_ssesolve(integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals) -(f::SaveFuncSSESolve{false,Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both all solvers +(f::SaveFuncSSESolve)(u, t, integrator) = + _save_func_ssesolve(u, integrator, f.e_ops, f.m_ops, f.progr, f.iter, f.expvals, f.m_expvals) +(f::SaveFuncSSESolve{false,Nothing})(u, t, integrator) = _save_func(integrator, f.progr) # Common for both all solvers _get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops) _get_m_ops_data(sc_ops, ::Type{SaveFuncSSESolve}) = map(op -> Hermitian(get_data(op) + get_data(op)'), sc_ops) @@ -32,8 +32,8 @@ _get_save_callback_idx(cb, ::Type{SaveFuncSSESolve}) = 2 # The first one is the ## # When e_ops is a list of operators -function _save_func_ssesolve(integrator, e_ops, m_ops, progr, iter, expvals, m_expvals) - ψ = integrator.u +function _save_func_ssesolve(u, integrator, e_ops, m_ops, progr, iter, expvals, m_expvals) + ψ = u _expect = op -> dot(ψ, op, ψ) diff --git a/test/core-test/time_evolution.jl b/test/core-test/time_evolution.jl index 55be066a3..bd5be390a 100644 --- a/test/core-test/time_evolution.jl +++ b/test/core-test/time_evolution.jl @@ -85,11 +85,11 @@ @testset "Memory Allocations" begin allocs_tot = @allocations sesolve(H, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false)) # Warm-up allocs_tot = @allocations sesolve(H, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false)) - @test allocs_tot < 150 + @test allocs_tot < 110 allocs_tot = @allocations sesolve(H, ψ0, tlist, saveat = [tlist[end]], progress_bar = Val(false)) # Warm-up allocs_tot = @allocations sesolve(H, ψ0, tlist, saveat = [tlist[end]], progress_bar = Val(false)) - @test allocs_tot < 100 + @test allocs_tot < 90 end @testset "Type Inference sesolve" begin @@ -327,21 +327,21 @@ allocs_tot = @allocations mesolve(L, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false)) # Warm-up allocs_tot = @allocations mesolve(L, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false)) - @test allocs_tot < 210 + @test allocs_tot < 180 allocs_tot = @allocations mesolve(L, ψ0, tlist, saveat = [tlist[end]], progress_bar = Val(false)) # Warm-up allocs_tot = @allocations mesolve(L, ψ0, tlist, saveat = [tlist[end]], progress_bar = Val(false)) - @test allocs_tot < 120 + @test allocs_tot < 110 allocs_tot = @allocations mesolve(L_td, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false), params = p) # Warm-up allocs_tot = @allocations mesolve(L_td, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false), params = p) - @test allocs_tot < 210 + @test allocs_tot < 180 allocs_tot = @allocations mesolve(L_td, ψ0, tlist, progress_bar = Val(false), saveat = [tlist[end]], params = p) # Warm-up allocs_tot = @allocations mesolve(L_td, ψ0, tlist, progress_bar = Val(false), saveat = [tlist[end]], params = p) - @test allocs_tot < 120 + @test allocs_tot < 110 end @testset "Memory Allocations (mcsolve)" begin @@ -350,7 +350,7 @@ @allocations mcsolve(H, ψ0, tlist, c_ops, e_ops = e_ops, ntraj = ntraj, progress_bar = Val(false)) # Warm-up allocs_tot = @allocations mcsolve(H, ψ0, tlist, c_ops, e_ops = e_ops, ntraj = ntraj, progress_bar = Val(false)) - @test allocs_tot < 160 * ntraj + 500 # 150 allocations per trajectory + 500 for initialization + @test allocs_tot < 120 * ntraj + 400 # 150 allocations per trajectory + 500 for initialization allocs_tot = @allocations mcsolve( H, @@ -370,22 +370,23 @@ saveat = [tlist[end]], progress_bar = Val(false), ) - @test allocs_tot < 160 * ntraj + 300 # 100 allocations per trajectory + 300 for initialization + @test allocs_tot < 110 * ntraj + 300 # 100 allocations per trajectory + 300 for initialization end @testset "Memory Allocations (ssesolve)" begin + ntraj = 100 allocs_tot = - @allocations ssesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, ntraj = 100, progress_bar = Val(false)) # Warm-up + @allocations ssesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, ntraj = ntraj, progress_bar = Val(false)) # Warm-up allocs_tot = - @allocations ssesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, ntraj = 100, progress_bar = Val(false)) - @test allocs_tot < 1950000 # TODO: Fix this high number of allocations + @allocations ssesolve(H, ψ0, tlist, c_ops, e_ops = e_ops, ntraj = ntraj, progress_bar = Val(false)) + @test allocs_tot < 1100 * ntraj + 400 # TODO: Fix this high number of allocations allocs_tot = @allocations ssesolve( H, ψ0, tlist, c_ops, - ntraj = 100, + ntraj = ntraj, saveat = [tlist[end]], progress_bar = Val(false), ) # Warm-up @@ -394,14 +395,15 @@ ψ0, tlist, c_ops, - ntraj = 100, + ntraj = ntraj, saveat = [tlist[end]], progress_bar = Val(false), ) - @test allocs_tot < 570000 # TODO: Fix this high number of allocations + @test allocs_tot < 1000 * ntraj + 300 # TODO: Fix this high number of allocations end @testset "Memory Allocations (smesolve)" begin + ntraj = 100 allocs_tot = @allocations smesolve( H, ψ0, @@ -409,7 +411,7 @@ c_ops_sme, sc_ops_sme, e_ops = e_ops, - ntraj = 100, + ntraj = ntraj, progress_bar = Val(false), ) # Warm-up allocs_tot = @allocations smesolve( @@ -419,10 +421,10 @@ c_ops_sme, sc_ops_sme, e_ops = e_ops, - ntraj = 100, + ntraj = ntraj, progress_bar = Val(false), ) - @test allocs_tot < 2750000 # TODO: Fix this high number of allocations + @test allocs_tot < 1100 * ntraj + 1800 # TODO: Fix this high number of allocations allocs_tot = @allocations smesolve( H, @@ -430,7 +432,7 @@ tlist, c_ops_sme, sc_ops_sme, - ntraj = 100, + ntraj = ntraj, saveat = [tlist[end]], progress_bar = Val(false), ) # Warm-up @@ -440,11 +442,56 @@ tlist, c_ops_sme, sc_ops_sme, - ntraj = 100, + ntraj = ntraj, + saveat = [tlist[end]], + progress_bar = Val(false), + ) + @test allocs_tot < 1000 * ntraj + 1500 # TODO: Fix this high number of allocations + + # Diagonal Noise Case + allocs_tot = @allocations smesolve( + H, + ψ0, + tlist, + c_ops_sme2, + sc_ops_sme2, + e_ops = e_ops, + ntraj = ntraj, + progress_bar = Val(false), + ) # Warm-up + allocs_tot = @allocations smesolve( + H, + ψ0, + tlist, + c_ops_sme2, + sc_ops_sme2, + e_ops = e_ops, + ntraj = 1, + progress_bar = Val(false), + ) + @test allocs_tot < 600 * ntraj + 1400 # TODO: Fix this high number of allocations + + allocs_tot = @allocations smesolve( + H, + ψ0, + tlist, + c_ops_sme2, + sc_ops_sme2, + ntraj = ntraj, + saveat = [tlist[end]], + progress_bar = Val(false), + ) # Warm-up + allocs_tot = @allocations smesolve( + H, + ψ0, + tlist, + c_ops_sme2, + sc_ops_sme2, + ntraj = 1, saveat = [tlist[end]], progress_bar = Val(false), ) - @test allocs_tot < 570000 # TODO: Fix this high number of allocations + @test allocs_tot < 550 * ntraj + 1000 # TODO: Fix this high number of allocations end @testset "Type Inference mesolve" begin