Skip to content

Commit 45a0945

Browse files
[no ci] Add support for OperatorKet input for mesolve and smesolve
1 parent 35e9c66 commit 45a0945

File tree

3 files changed

+40
-18
lines changed

3 files changed

+40
-18
lines changed

src/time_evolution/mesolve.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ where
3333
# Arguments
3434
3535
- `H`: Hamiltonian of the system ``\hat{H}``. It can be either a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a `Tuple` of operator-function pairs.
36-
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref) or a [`Operator`](@ref).
36+
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref), [`Operator`](@ref) or [`OperatorKet`](@ref).
3737
- `tlist`: List of times at which to save either the state or the expectation values of the system.
3838
- `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`.
3939
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
@@ -65,7 +65,7 @@ function mesolveProblem(
6565
kwargs...,
6666
) where {
6767
HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject},
68-
StateOpType<:Union{KetQuantumObject,OperatorQuantumObject},
68+
StateOpType<:Union{KetQuantumObject,OperatorQuantumObject,OperatorKetQuantumObject},
6969
}
7070
haskey(kwargs, :save_idxs) &&
7171
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
@@ -76,7 +76,11 @@ function mesolveProblem(
7676
check_dimensions(L_evo, ψ0)
7777

7878
T = Base.promote_eltype(L_evo, ψ0)
79-
ρ0 = to_dense(_CType(T), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type
79+
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
80+
to_dense(_CType(T), copy(ψ0.data))
81+
else
82+
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
83+
end
8084
L = L_evo.data
8185

8286
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
@@ -85,7 +89,7 @@ function mesolveProblem(
8589
tspan = (tlist[1], tlist[end])
8690
prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...)
8791

88-
return TimeEvolutionProblem(prob, tlist, L_evo.dimensions)
92+
return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket=Val(isoperket(ψ0)),))
8993
end
9094

9195
@doc raw"""
@@ -117,7 +121,7 @@ where
117121
# Arguments
118122
119123
- `H`: Hamiltonian of the system ``\hat{H}``. It can be either a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a `Tuple` of operator-function pairs.
120-
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref) or a [`Operator`](@ref).
124+
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref), [`Operator`](@ref) or [`OperatorKet`](@ref).
121125
- `tlist`: List of times at which to save either the state or the expectation values of the system.
122126
- `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`.
123127
- `alg`: The algorithm for the ODE solver. The default value is `Tsit5()`.
@@ -152,7 +156,7 @@ function mesolve(
152156
kwargs...,
153157
) where {
154158
HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject},
155-
StateOpType<:Union{KetQuantumObject,OperatorQuantumObject},
159+
StateOpType<:Union{KetQuantumObject,OperatorQuantumObject,OperatorKetQuantumObject},
156160
}
157161
prob = mesolveProblem(
158162
H,
@@ -173,7 +177,12 @@ end
173177
function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
174178
sol = solve(prob.prob, alg)
175179

176-
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator, dims = prob.dimensions), sol.u)
180+
# No type instabilities since `isoperket` is a Val, and so it is known at compile time
181+
if getVal(prob.kwargs.isoperket)
182+
ρt = map-> QuantumObject(ϕ, type = OperatorKet, dims = prob.dimensions), sol.u)
183+
else
184+
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator, dims = prob.dimensions), sol.u)
185+
end
177186

178187
return TimeEvolutionSol(
179188
prob.times,

src/time_evolution/smesolve.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
export smesolveProblem, smesolveEnsembleProblem, smesolve
22

3-
_smesolve_generate_state(u, dims) = QuantumObject(vec2mat(u), type = Operator, dims = dims)
3+
_smesolve_generate_state(u, dims, isoperket::Val{false}) = QuantumObject(vec2mat(u), type = Operator, dims = dims)
4+
_smesolve_generate_state(u, dims, isoperket::Val{true}) = QuantumObject(u, type = OperatorKet, dims = dims)
45

56
function _smesolve_update_coeff(u, p, t, op_vec)
67
return 2 * real(dot(op_vec, u)) #this is Tr[Sn * ρ + ρ * Sn']
@@ -47,7 +48,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
4748
# Arguments
4849
4950
- `H`: Hamiltonian of the system ``\hat{H}``. It can be either a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a `Tuple` of operator-function pairs.
50-
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref) or a [`Operator`](@ref).
51+
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref), [`Operator`](@ref) or [`OperatorKet`](@ref).
5152
- `tlist`: List of times at which to save either the state or the expectation values of the system.
5253
- `c_ops`: List of collapse operators ``\{\hat{C}_i\}_i``. It can be either a `Vector` or a `Tuple`.
5354
- `sc_ops`: List of stochastic collapse operators ``\{\hat{S}_n\}_n``. It can be either a `Vector`, a `Tuple` or a [`AbstractQuantumObject`](@ref). It is recommended to use the last case when only one operator is provided.
@@ -84,7 +85,7 @@ function smesolveProblem(
8485
progress_bar::Union{Val,Bool} = Val(true),
8586
store_measurement::Union{Val,Bool} = Val(false),
8687
kwargs...,
87-
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject}}
88+
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject,OperatorKetQuantumObject}}
8889
haskey(kwargs, :save_idxs) &&
8990
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
9091

@@ -100,7 +101,11 @@ function smesolveProblem(
100101
dims = L_evo.dimensions
101102

102103
T = Base.promote_eltype(L_evo, ψ0)
103-
ρ0 = to_dense(_CType(T), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type
104+
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
105+
to_dense(_CType(T), copy(ψ0.data))
106+
else
107+
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
108+
end
104109

105110
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))
106111

@@ -143,7 +148,7 @@ function smesolveProblem(
143148
kwargs3...,
144149
)
145150

146-
return TimeEvolutionProblem(prob, tlist, dims)
151+
return TimeEvolutionProblem(prob, tlist, dims, (isoperket=Val(isoperket(ψ0)),))
147152
end
148153

149154
@doc raw"""
@@ -188,7 +193,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
188193
# Arguments
189194
190195
- `H`: Hamiltonian of the system ``\hat{H}``. It can be either a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a `Tuple` of operator-function pairs.
191-
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref) or a [`Operator`](@ref).
196+
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref), [`Operator`](@ref) or [`OperatorKet`](@ref).
192197
- `tlist`: List of times at which to save either the state or the expectation values of the system.
193198
- `c_ops`: List of collapse operators ``\{\hat{C}_i\}_i``. It can be either a `Vector` or a `Tuple`.
194199
- `sc_ops`: List of stochastic collapse operators ``\{\hat{S}_n\}_n``. It can be either a `Vector`, a `Tuple` or a [`AbstractQuantumObject`](@ref). It is recommended to use the last case when only one operator is provided.
@@ -233,7 +238,7 @@ function smesolveEnsembleProblem(
233238
progress_bar::Union{Val,Bool} = Val(true),
234239
store_measurement::Union{Val,Bool} = Val(false),
235240
kwargs...,
236-
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject}}
241+
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject,OperatorKetQuantumObject}}
237242
_prob_func =
238243
isnothing(prob_func) ?
239244
_ensemble_dispatch_prob_func(
@@ -266,7 +271,7 @@ function smesolveEnsembleProblem(
266271
EnsembleProblem(prob_sme, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
267272
prob_sme.times,
268273
prob_sme.dimensions,
269-
(progr = _output_func[2], channel = _output_func[3]),
274+
merge(prob_sme.kwargs, (progr = _output_func[2], channel = _output_func[3])),
270275
)
271276

272277
return ensemble_prob
@@ -315,7 +320,7 @@ Above, ``\hat{C}_i`` represent the collapse operators related to pure dissipatio
315320
# Arguments
316321
317322
- `H`: Hamiltonian of the system ``\hat{H}``. It can be either a [`QuantumObject`](@ref), a [`QuantumObjectEvolution`](@ref), or a `Tuple` of operator-function pairs.
318-
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref) or a [`Operator`](@ref).
323+
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``. It can be either a [`Ket`](@ref), [`Operator`](@ref) or [`OperatorKet`](@ref).
319324
- `tlist`: List of times at which to save either the state or the expectation values of the system.
320325
- `c_ops`: List of collapse operators ``\{\hat{C}_i\}_i``. It can be either a `Vector` or a `Tuple`.
321326
- `sc_ops`: List of stochastic collapse operators ``\{\hat{S}_n\}_n``. It can be either a `Vector`, a `Tuple` or a [`AbstractQuantumObject`](@ref). It is recommended to use the last case when only one operator is provided.
@@ -362,7 +367,7 @@ function smesolve(
362367
progress_bar::Union{Val,Bool} = Val(true),
363368
store_measurement::Union{Val,Bool} = Val(false),
364369
kwargs...,
365-
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject}}
370+
) where {StateOpType<:Union{KetQuantumObject,OperatorQuantumObject,OperatorKetQuantumObject}}
366371
ensemble_prob = smesolveEnsembleProblem(
367372
H,
368373
ψ0,
@@ -406,7 +411,8 @@ function smesolve(
406411
_expvals_all =
407412
_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMESolve), eachindex(sol))
408413
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
409-
states = map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims)), eachindex(sol))
414+
415+
states = map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), ens_prob.kwargs.isoperket), eachindex(sol))
410416

411417
_m_expvals =
412418
_m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol))

test/core-test/time_evolution.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@
158158
)
159159
sol_sme3 = smesolve(H, ψ0, tlist, c_ops_sme2, sc_ops_sme2, e_ops = e_ops, progress_bar = Val(false))
160160

161+
# For testing the `OperatorKet` input
162+
sol_me4 = mesolve(H, operator_to_vector(ket2dm(ψ0)), tlist, c_ops, saveat=saveat, progress_bar = Val(false))
163+
sol_sme4 = smesolve(H, ψ0, tlist, c_ops_sme, sc_ops_sme, saveat=saveat, ntraj=10, progress_bar = Val(false), rng = MersenneTwister(12))
164+
sol_sme5 = smesolve(H, operator_to_vector(ket2dm(ψ0)), tlist, c_ops_sme, sc_ops_sme, saveat=saveat, ntraj=10, progress_bar = Val(false), rng = MersenneTwister(12))
165+
161166
ρt_mc = [ket2dm.(normalize.(states)) for states in sol_mc_states.states]
162167
expect_mc_states = mapreduce(states -> expect.(Ref(e_ops[1]), states), hcat, ρt_mc)
163168
expect_mc_states_mean = sum(expect_mc_states, dims = 2) / size(expect_mc_states, 2)
@@ -190,6 +195,7 @@
190195
@test length(sol_me3.states) == length(saveat)
191196
@test size(sol_me3.expect) == (length(e_ops), length(tlist))
192197
@test sol_me3.expect[1, saveat_idxs] expect(e_ops[1], sol_me3.states) atol = 1e-6
198+
@test all([sol_me3.states[i] vector_to_operator(sol_me4.states[i]) for i in eachindex(saveat)])
193199
@test length(sol_mc.times) == length(tlist)
194200
@test size(sol_mc.expect) == (length(e_ops), length(tlist))
195201
@test length(sol_mc_states.times) == length(tlist)
@@ -202,6 +208,7 @@
202208
@test isnothing(sol_sme.measurement)
203209
@test size(sol_sse2.measurement) == (length(c_ops), 20, length(tlist) - 1)
204210
@test size(sol_sme2.measurement) == (length(sc_ops_sme), 20, length(tlist) - 1)
211+
@test all([sol_sme4.states[j][i] vector_to_operator(sol_sme5.states[j][i]) for i in eachindex(saveat), j in 1:10])
205212

206213
@test sol_me_string ==
207214
"Solution of time evolution\n" *

0 commit comments

Comments
 (0)