|
1 | 1 | export sesolveProblem, sesolve |
2 | 2 |
|
| 3 | +# When e_ops is Nothing |
3 | 4 | function _save_func_sesolve(integrator) |
4 | | - internal_params = integrator.p |
5 | | - progr = internal_params.progr |
6 | | - |
7 | | - if !internal_params.is_empty_e_ops |
8 | | - e_ops = internal_params.e_ops |
9 | | - expvals = internal_params.expvals |
| 5 | + next!(integrator.p.progr) |
| 6 | + return u_modified!(integrator, false) |
| 7 | +end |
10 | 8 |
|
| 9 | +# When e_ops is a list of operators |
| 10 | +function _save_func_sesolve(integrator, e_ops, is_empty_e_ops) |
| 11 | + expvals = integrator.p.expvals |
| 12 | + progr = integrator.p.progr |
| 13 | + if !is_empty_e_ops |
11 | 14 | ψ = integrator.u |
12 | | - _expect = op -> dot(ψ, op, ψ) |
| 15 | + _expect = op -> dot(ψ, get_data(op), ψ) |
13 | 16 | @. expvals[:, progr.counter[]+1] = _expect(e_ops) |
14 | 17 | end |
15 | | - next!(progr) |
16 | | - return u_modified!(integrator, false) |
| 18 | + return _save_func_sesolve(integrator) |
| 19 | +end |
| 20 | + |
| 21 | +# Generate the callback depending on the e_ops type |
| 22 | +function _generate_sesolve_callback(e_ops::Nothing, tlist) |
| 23 | + f = integrator -> _save_func_sesolve(integrator) |
| 24 | + return PresetTimeCallback(tlist, f, save_positions = (false, false)) |
| 25 | +end |
| 26 | + |
| 27 | +function _generate_sesolve_callback(e_ops, tlist) |
| 28 | + is_empty_e_ops = isempty(e_ops) |
| 29 | + f = integrator -> _save_func_sesolve(integrator, e_ops, is_empty_e_ops) |
| 30 | + return PresetTimeCallback(tlist, f, save_positions = (false, false)) |
17 | 31 | end |
18 | 32 |
|
19 | | -function _generate_sesolve_kwargs_with_callback(tlist, kwargs) |
20 | | - cb1 = PresetTimeCallback(tlist, _save_func_sesolve, save_positions = (false, false)) |
| 33 | +function _merge_sesolve_kwargs_with_callback(kwargs, cb) |
21 | 34 | kwargs2 = |
22 | | - haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) : |
23 | | - merge(kwargs, (callback = cb1,)) |
| 35 | + haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb),)) : |
| 36 | + merge(kwargs, (callback = cb,)) |
24 | 37 |
|
25 | 38 | return kwargs2 |
26 | 39 | end |
27 | 40 |
|
| 41 | +# Multiple dispatch depending on the progress_bar and e_ops types |
28 | 42 | function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, tlist, kwargs) |
29 | | - return _generate_sesolve_kwargs_with_callback(tlist, kwargs) |
| 43 | + cb = _generate_sesolve_callback(e_ops, tlist) |
| 44 | + return _merge_sesolve_kwargs_with_callback(kwargs, cb) |
30 | 45 | end |
31 | 46 |
|
32 | 47 | function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, tlist, kwargs) |
33 | 48 | if e_ops isa Nothing |
34 | 49 | return kwargs |
35 | 50 | end |
36 | | - return _generate_sesolve_kwargs_with_callback(tlist, kwargs) |
| 51 | + cb = _generate_sesolve_callback(e_ops, tlist) |
| 52 | + return _merge_sesolve_kwargs_with_callback(kwargs, cb) |
37 | 53 | end |
38 | 54 |
|
39 | 55 | _sesolve_make_U_QobjEvo(H::QuantumObjectEvolution{<:MatrixOperator}) = |
@@ -103,31 +119,23 @@ function sesolveProblem( |
103 | 119 |
|
104 | 120 | if e_ops isa Nothing |
105 | 121 | expvals = Array{ComplexF64}(undef, 0, length(tlist)) |
106 | | - e_ops_data = () |
107 | 122 | is_empty_e_ops = true |
108 | 123 | else |
109 | 124 | expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) |
110 | | - e_ops_data = get_data.(e_ops) |
111 | 125 | is_empty_e_ops = isempty(e_ops) |
112 | 126 | end |
113 | 127 |
|
114 | | - p = ( |
115 | | - e_ops = e_ops_data, |
116 | | - expvals = expvals, |
117 | | - progr = progr, |
118 | | - times = tlist, |
119 | | - Hdims = H_evo.dims, |
120 | | - is_empty_e_ops = is_empty_e_ops, |
121 | | - params..., |
122 | | - ) |
| 128 | + p = QuantumTimeEvoParameters(expvals, progr, params) |
123 | 129 |
|
124 | 130 | saveat = is_empty_e_ops ? tlist : [tlist[end]] |
125 | 131 | default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat) |
126 | 132 | kwargs2 = merge(default_values, kwargs) |
127 | 133 | kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2) |
128 | 134 |
|
129 | 135 | tspan = (tlist[1], tlist[end]) |
130 | | - return ODEProblem{true,FullSpecialize}(U, ψ0, tspan, p; kwargs3...) |
| 136 | + prob = ODEProblem{true,FullSpecialize}(U, ψ0, tspan, p; kwargs3...) |
| 137 | + |
| 138 | + return QuantumTimeEvoProblem(prob, tlist, H_evo.dims) |
131 | 139 | end |
132 | 140 |
|
133 | 141 | @doc raw""" |
@@ -186,13 +194,13 @@ function sesolve( |
186 | 194 | return sesolve(prob, alg) |
187 | 195 | end |
188 | 196 |
|
189 | | -function sesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) |
190 | | - sol = solve(prob, alg) |
| 197 | +function sesolve(prob::QuantumTimeEvoProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) |
| 198 | + sol = solve(prob.prob, alg) |
191 | 199 |
|
192 | | - ψt = map(ϕ -> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u) |
| 200 | + ψt = map(ϕ -> QuantumObject(ϕ, type = Ket, dims = prob.dims), sol.u) |
193 | 201 |
|
194 | 202 | return TimeEvolutionSol( |
195 | | - sol.prob.p.times, |
| 203 | + prob.times, |
196 | 204 | ψt, |
197 | 205 | sol.prob.p.expvals, |
198 | 206 | sol.retcode, |
|
0 commit comments