Skip to content

Commit 9665194

Browse files
Added operator support. Look at testing.ipynb ##Testing Solver Extensions for usage.
1 parent 571abc9 commit 9665194

File tree

10 files changed

+1172
-62
lines changed

10 files changed

+1172
-62
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ Manifest.toml
1010
benchmarks/benchmarks_output.json
1111

1212
.ipynb_checkpoints
13-
*.ipynb
1413
.devcontainer/*

src/qobj/functions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,12 @@ Converts a sparse QuantumObject to a dense QuantumObject.
119119
to_dense(A::QuantumObject) = QuantumObject(to_dense(A.data), A.type, A.dimensions)
120120
to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
121121
to_dense(A::MT) where {MT<:AbstractArray} = A
122+
to_dense(A::Diagonal) = diagm(A.diag)
122123

123124
to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)
124125
to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
125126
to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A
127+
to_dense(::Type{T}, A::Diagonal{T}) where {T<:Number} = diagm(A.diag)
126128

127129
function to_dense(::Type{M}) where {M<:Union{Diagonal,SparseMatrixCSC}}
128130
T = M

src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ struct LindbladJump{
1818
T2,
1919
RNGType<:AbstractRNG,
2020
RandT,
21-
CT<:AbstractVector,
21+
CT<:AbstractArray,
2222
WT<:AbstractVector,
2323
JTT<:AbstractVector,
2424
JWT<:AbstractVector,

src/time_evolution/mcsolve.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ function _mcsolve_output_func(sol, i)
2020
return (sol, false)
2121
end
2222

23-
function _normalize_state!(u, dims, normalize_states)
23+
function _normalize_state!(u, dims, normalize_states, type)
2424
getVal(normalize_states) && normalize!(u)
25-
return QuantumObject(u, Ket(), dims)
25+
return QuantumObject(u, type(), dims)
2626
end
2727

2828
function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops)
@@ -110,15 +110,15 @@ If the environmental measurements register a quantum jump, the wave function und
110110
"""
111111
function mcsolveProblem(
112112
H::Union{AbstractQuantumObject{Operator},Tuple},
113-
ψ0::QuantumObject{Ket},
113+
ψ0::QuantumObject{X},
114114
tlist::AbstractVector,
115115
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
116116
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
117117
params = NullParameters(),
118118
rng::AbstractRNG = default_rng(),
119119
jump_callback::TJC = ContinuousLindbladJumpCallback(),
120120
kwargs...,
121-
) where {TJC<:LindbladJumpCallbackType}
121+
) where {TJC<:LindbladJumpCallbackType,X<:Union{Ket,Operator}}
122122
haskey(kwargs, :save_idxs) &&
123123
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
124124

@@ -221,7 +221,7 @@ If the environmental measurements register a quantum jump, the wave function und
221221
"""
222222
function mcsolveEnsembleProblem(
223223
H::Union{AbstractQuantumObject{Operator},Tuple},
224-
ψ0::QuantumObject{Ket},
224+
ψ0::QuantumObject{X},
225225
tlist::AbstractVector,
226226
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
227227
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
@@ -234,7 +234,7 @@ function mcsolveEnsembleProblem(
234234
prob_func::Union{Function,Nothing} = nothing,
235235
output_func::Union{Tuple,Nothing} = nothing,
236236
kwargs...,
237-
) where {TJC<:LindbladJumpCallbackType}
237+
) where {TJC<:LindbladJumpCallbackType,X<:Union{Ket,Operator}}
238238
_prob_func = isnothing(prob_func) ? _ensemble_dispatch_prob_func(rng, ntraj, tlist, _mcsolve_prob_func) : prob_func
239239
_output_func =
240240
output_func isa Nothing ?
@@ -261,6 +261,7 @@ function mcsolveEnsembleProblem(
261261
ensemble_prob = TimeEvolutionProblem(
262262
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
263263
prob_mc.times,
264+
X,
264265
prob_mc.dimensions,
265266
(progr = _output_func[2], channel = _output_func[3]),
266267
)
@@ -358,7 +359,7 @@ If the environmental measurements register a quantum jump, the wave function und
358359
"""
359360
function mcsolve(
360361
H::Union{AbstractQuantumObject{Operator},Tuple},
361-
ψ0::QuantumObject{Ket},
362+
ψ0::QuantumObject{X},
362363
tlist::AbstractVector,
363364
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
364365
alg::AbstractODEAlgorithm = DP5(),
@@ -374,7 +375,7 @@ function mcsolve(
374375
keep_runs_results::Union{Val,Bool} = Val(false),
375376
normalize_states::Union{Val,Bool} = Val(true),
376377
kwargs...,
377-
) where {TJC<:LindbladJumpCallbackType}
378+
) where {TJC<:LindbladJumpCallbackType} where {X<:Union{Ket,Operator}}
378379
ens_prob_mc = mcsolveEnsembleProblem(
379380
H,
380381
ψ0,
@@ -415,7 +416,11 @@ function mcsolve(
415416
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
416417

417418
# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
418-
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
419+
# states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
420+
states_all = stack(
421+
map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, ens_prob_mc.states_type), eachindex(sol)),
422+
dims = 1,
423+
)
419424

420425
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
421426
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))

src/time_evolution/mesolve.jl

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ _mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvil
66
_mesolve_make_L_QobjEvo(H::Nothing, c_ops::Nothing) = throw(ArgumentError("Both H and
77
c_ops are Nothing. You are probably running the wrong function."))
88

9-
function _gen_mesolve_solution(sol, times, dimensions, isoperket::Val)
10-
if getVal(isoperket)
11-
ρt = map-> QuantumObject(ϕ, type = OperatorKet(), dims = dimensions), sol.u)
9+
function _gen_mesolve_solution(sol, prob::TimeEvolutionProblem{X}) where {X<:Union{Operator,OperatorKet,SuperOperator}}
10+
if X() == Operator()
11+
ρt = map-> QuantumObject(vec2mat(ϕ), type = X(), dims = prob.dimensions), sol.u)
1212
else
13-
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator(), dims = dimensions), sol.u)
13+
ρt = map-> QuantumObject(ϕ, type = X(), dims = prob.dimensions), sol.u)
1414
end
1515

1616
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility
1717

1818
return TimeEvolutionSol(
19-
times,
19+
prob.times,
2020
sol.t,
2121
ρt,
2222
_get_expvals(sol, SaveFuncMESolve),
@@ -86,8 +86,8 @@ function mesolveProblem(
8686
progress_bar::Union{Val,Bool} = Val(true),
8787
inplace::Union{Val,Bool} = Val(true),
8888
kwargs...,
89-
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
90-
(isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolveProblem(
89+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
90+
(isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolveProblem(
9191
H,
9292
ψ0,
9393
tlist;
@@ -107,11 +107,27 @@ function mesolveProblem(
107107
check_dimensions(L_evo, ψ0)
108108

109109
T = Base.promote_eltype(L_evo, ψ0)
110-
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
111-
to_dense(_complex_float_type(T), copy(ψ0.data))
110+
# ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
111+
# to_dense(_complex_float_type(T), copy(ψ0.data))
112+
# else
113+
# to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
114+
# end
115+
if isoper(ψ0)
116+
ρ0 = to_dense(_complex_float_type(T), mat2vec(ψ0.data))
117+
state_type = Operator()
118+
elseif isoperket(ψ0)
119+
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
120+
state_type = OperatorKet()
121+
elseif isket(ψ0)
122+
ρ0 = to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
123+
state_type = Operator()
124+
elseif issuper(ψ0)
125+
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
126+
state_type = SuperOperator()
112127
else
113-
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
128+
throw(ArgumentError("Unsupported state type for ψ0 in mesolveProblem."))
114129
end
130+
115131
L = cache_operator(L_evo.data, ρ0)
116132

117133
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
@@ -122,7 +138,7 @@ function mesolveProblem(
122138

123139
prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs4...)
124140

125-
return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),))
141+
return TimeEvolutionProblem(prob, tlist, state_type, L_evo.dimensions)#, (isoperket = Val(isoperket(ψ0)),))
126142
end
127143

128144
@doc raw"""
@@ -188,8 +204,8 @@ function mesolve(
188204
progress_bar::Union{Val,Bool} = Val(true),
189205
inplace::Union{Val,Bool} = Val(true),
190206
kwargs...,
191-
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
192-
(isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolve(
207+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
208+
(isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolve(
193209
H,
194210
ψ0,
195211
tlist;
@@ -230,7 +246,7 @@ end
230246
function mesolve(prob::TimeEvolutionProblem, alg::AbstractODEAlgorithm = DP5(); kwargs...)
231247
sol = solve(prob.prob, alg; kwargs...)
232248

233-
return _gen_mesolve_solution(sol, prob.times, prob.dimensions, prob.kwargs.isoperket)
249+
return _gen_mesolve_solution(sol, prob)#, prob.kwargs.isoperket)
234250
end
235251

236252
@doc raw"""
@@ -298,8 +314,8 @@ function mesolve_map(
298314
params::Union{NullParameters,Tuple} = NullParameters(),
299315
progress_bar::Union{Val,Bool} = Val(true),
300316
kwargs...,
301-
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
302-
(isoper(H) && all(isket, ψ0) && isnothing(c_ops)) && return sesolve_map(
317+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
318+
(isoper(H) && (all(isket, ψ0) || all(isoper, ψ0)) && isnothing(c_ops)) && return sesolve_map(
303319
H,
304320
ψ0,
305321
tlist;
@@ -315,10 +331,16 @@ function mesolve_map(
315331
# Convert to appropriate format based on state type
316332
ψ0_iter = map(ψ0) do state
317333
T = _complex_float_type(eltype(state))
318-
if isoperket(state)
319-
to_dense(T, copy(state.data))
334+
if isoper(state)
335+
to_dense(_complex_float_type(T), mat2vec(state.data))
336+
elseif isoperket(state)
337+
to_dense(_complex_float_type(T), copy(state.data))
338+
elseif isket(state)
339+
to_dense(_complex_float_type(T), mat2vec(ket2dm(state).data))
340+
elseif issuper(state)
341+
to_dense(_complex_float_type(T), copy(state.data))
320342
else
321-
to_dense(T, mat2vec(ket2dm(state).data))
343+
throw(ArgumentError("Unsupported state type for ψ0 in mesolveProblem."))
322344
end
323345
end
324346
if params isa NullParameters
@@ -347,7 +369,7 @@ mesolve_map(
347369
tlist::AbstractVector,
348370
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
349371
kwargs...,
350-
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
372+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} =
351373
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)
352374

353375
# this method is for advanced usage
@@ -357,14 +379,14 @@ mesolve_map(
357379
#
358380
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
359381
function mesolve_map(
360-
prob::TimeEvolutionProblem{<:ODEProblem},
382+
prob::TimeEvolutionProblem{StateOpType,<:ODEProblem},
361383
iter::AbstractArray,
362384
alg::AbstractODEAlgorithm = DP5(),
363385
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
364386
prob_func::Union{Function,Nothing} = nothing,
365387
output_func::Union{Tuple,Nothing} = nothing,
366388
progress_bar::Union{Val,Bool} = Val(true),
367-
)
389+
) where {StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
368390
# generate ensemble problem
369391
ntraj = length(iter)
370392
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
@@ -380,14 +402,14 @@ function mesolve_map(
380402
ens_prob = TimeEvolutionProblem(
381403
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
382404
prob.times,
405+
StateOpType(),
383406
prob.dimensions,
384-
(progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket),
407+
(progr = _output_func[2], channel = _output_func[3]),
385408
)
386409

387410
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
388411

389412
# handle solution and make it become an Array of TimeEvolutionSol
390-
sol_vec =
391-
[_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable
413+
sol_vec = [_gen_mesolve_solution(sol[:, i], prob) for i in eachindex(sol)] # map is type unstable
392414
return reshape(sol_vec, size(iter))
393415
end

0 commit comments

Comments
 (0)