|
| 1 | +#= |
| 2 | +This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback. |
| 3 | +=# |
| 4 | + |
| 5 | +## |
| 6 | + |
| 7 | +# Multiple dispatch depending on the progress_bar and e_ops types |
| 8 | +function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method) |
| 9 | + cb = _generate_save_callback(e_ops, tlist, progress_bar, method) |
| 10 | + return _merge_kwargs_with_callback(kwargs, cb) |
| 11 | +end |
| 12 | +_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs |
| 13 | + |
| 14 | +function _merge_kwargs_with_callback(kwargs, cb) |
| 15 | + kwargs2 = |
| 16 | + haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) : |
| 17 | + merge(kwargs, (callback = cb,)) |
| 18 | + |
| 19 | + return kwargs2 |
| 20 | +end |
| 21 | + |
| 22 | +function _generate_save_callback(e_ops, tlist, progress_bar, method) |
| 23 | + e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method) |
| 24 | + |
| 25 | + progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing |
| 26 | + |
| 27 | + expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist)) |
| 28 | + |
| 29 | + _save_affect! = method(e_ops_data, progr, Ref(1), expvals) |
| 30 | + return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)) |
| 31 | +end |
| 32 | + |
| 33 | +_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops) |
| 34 | +_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10 |
| 35 | + |
| 36 | +_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op))) |
| 37 | + |
| 38 | +## |
| 39 | + |
| 40 | +# When e_ops is Nothing. Common for both mesolve and sesolve |
| 41 | +function _save_func(integrator, progr) |
| 42 | + next!(progr) |
| 43 | + u_modified!(integrator, false) |
| 44 | + return nothing |
| 45 | +end |
| 46 | + |
| 47 | +# When progr is Nothing. Common for both mesolve and sesolve |
| 48 | +function _save_func(integrator, progr::Nothing) |
| 49 | + u_modified!(integrator, false) |
| 50 | + return nothing |
| 51 | +end |
| 52 | + |
| 53 | +## |
| 54 | + |
| 55 | +# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`. |
| 56 | +function _se_me_sse_get_expvals(sol::AbstractODESolution) |
| 57 | + cb = _se_me_sse_get_save_callback(sol) |
| 58 | + if cb isa Nothing |
| 59 | + return nothing |
| 60 | + else |
| 61 | + return cb.affect!.expvals |
| 62 | + end |
| 63 | +end |
| 64 | + |
| 65 | +function _se_me_sse_get_save_callback(sol::AbstractODESolution) |
| 66 | + kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl |
| 67 | + if hasproperty(kwargs, :callback) |
| 68 | + return _se_me_sse_get_save_callback(kwargs.callback) |
| 69 | + else |
| 70 | + return nothing |
| 71 | + end |
| 72 | +end |
| 73 | +_se_me_sse_get_save_callback(integrator::AbstractODEIntegrator) = _se_me_sse_get_save_callback(integrator.opts.callback) |
| 74 | +function _se_me_sse_get_save_callback(cb::CallbackSet) |
| 75 | + cbs_discrete = cb.discrete_callbacks |
| 76 | + if length(cbs_discrete) > 0 |
| 77 | + _cb = cb.discrete_callbacks[1] |
| 78 | + return _se_me_sse_get_save_callback(_cb) |
| 79 | + else |
| 80 | + return nothing |
| 81 | + end |
| 82 | +end |
| 83 | +_se_me_sse_get_save_callback(cb::DiscreteCallback) = |
| 84 | + if (cb.affect! isa SaveFuncSESolve) || (cb.affect! isa SaveFuncMESolve) |
| 85 | + return cb |
| 86 | + else |
| 87 | + return nothing |
| 88 | + end |
| 89 | +_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing |
0 commit comments