Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Manifest.toml
benchmarks/benchmarks_output.json

.ipynb_checkpoints
*.ipynb
.devcontainer/*
2 changes: 2 additions & 0 deletions src/qobj/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ Converts a sparse QuantumObject to a dense QuantumObject.
to_dense(A::QuantumObject) = QuantumObject(to_dense(A.data), A.type, A.dimensions)
to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
to_dense(A::MT) where {MT<:AbstractArray} = A
to_dense(A::Diagonal) = diagm(A.diag)

to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)
to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A
to_dense(::Type{T}, A::Diagonal{T}) where {T<:Number} = diagm(A.diag)

function to_dense(::Type{M}) where {M<:Union{Diagonal,SparseMatrixCSC}}
T = M
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct LindbladJump{
T2,
RNGType<:AbstractRNG,
RandT,
CT<:AbstractVector,
CT<:AbstractArray,
WT<:AbstractVector,
JTT<:AbstractVector,
JWT<:AbstractVector,
Expand Down
23 changes: 14 additions & 9 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ function _mcsolve_output_func(sol, i)
return (sol, false)
end

function _normalize_state!(u, dims, normalize_states)
function _normalize_state!(u, dims, normalize_states, type)
getVal(normalize_states) && normalize!(u)
return QuantumObject(u, Ket(), dims)
return QuantumObject(u, type(), dims)
end

function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops)
Expand Down Expand Up @@ -110,15 +110,15 @@ If the environmental measurements register a quantum jump, the wave function und
"""
function mcsolveProblem(
H::Union{AbstractQuantumObject{Operator},Tuple},
ψ0::QuantumObject{Ket},
ψ0::QuantumObject{X},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We seldom use X here in Julia. Usually will be T or ST to represent (State) Type.

Can you also change it for the other places ?

tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
params = NullParameters(),
rng::AbstractRNG = default_rng(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
kwargs...,
) where {TJC<:LindbladJumpCallbackType}
) where {TJC<:LindbladJumpCallbackType,X<:Union{Ket,Operator}}
haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

Expand Down Expand Up @@ -221,7 +221,7 @@ If the environmental measurements register a quantum jump, the wave function und
"""
function mcsolveEnsembleProblem(
H::Union{AbstractQuantumObject{Operator},Tuple},
ψ0::QuantumObject{Ket},
ψ0::QuantumObject{X},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
Expand All @@ -234,7 +234,7 @@ function mcsolveEnsembleProblem(
prob_func::Union{Function,Nothing} = nothing,
output_func::Union{Tuple,Nothing} = nothing,
kwargs...,
) where {TJC<:LindbladJumpCallbackType}
) where {TJC<:LindbladJumpCallbackType,X<:Union{Ket,Operator}}
_prob_func = isnothing(prob_func) ? _ensemble_dispatch_prob_func(rng, ntraj, tlist, _mcsolve_prob_func) : prob_func
_output_func =
output_func isa Nothing ?
Expand All @@ -261,6 +261,7 @@ function mcsolveEnsembleProblem(
ensemble_prob = TimeEvolutionProblem(
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
prob_mc.times,
X,
prob_mc.dimensions,
(progr = _output_func[2], channel = _output_func[3]),
)
Expand Down Expand Up @@ -358,7 +359,7 @@ If the environmental measurements register a quantum jump, the wave function und
"""
function mcsolve(
H::Union{AbstractQuantumObject{Operator},Tuple},
ψ0::QuantumObject{Ket},
ψ0::QuantumObject{X},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
alg::AbstractODEAlgorithm = DP5(),
Expand All @@ -374,7 +375,7 @@ function mcsolve(
keep_runs_results::Union{Val,Bool} = Val(false),
normalize_states::Union{Val,Bool} = Val(true),
kwargs...,
) where {TJC<:LindbladJumpCallbackType}
) where {TJC<:LindbladJumpCallbackType} where {X<:Union{Ket,Operator}}
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
Expand Down Expand Up @@ -415,7 +416,11 @@ function mcsolve(
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP

# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
# states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
states_all = stack(
map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, ens_prob_mc.states_type), eachindex(sol)),
dims = 1,
)

col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
Expand Down
72 changes: 47 additions & 25 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ _mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvil
_mesolve_make_L_QobjEvo(H::Nothing, c_ops::Nothing) = throw(ArgumentError("Both H and
c_ops are Nothing. You are probably running the wrong function."))

function _gen_mesolve_solution(sol, times, dimensions, isoperket::Val)
if getVal(isoperket)
ρt = map(ϕ -> QuantumObject(ϕ, type = OperatorKet(), dims = dimensions), sol.u)
function _gen_mesolve_solution(sol, prob::TimeEvolutionProblem{X}) where {X<:Union{Operator,OperatorKet,SuperOperator}}
if X() == Operator()
ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = X(), dims = prob.dimensions), sol.u)
else
ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = dimensions), sol.u)
ρt = map(ϕ -> QuantumObject(ϕ, type = X(), dims = prob.dimensions), sol.u)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, just use prob.states_type

Suggested change
ρt = map-> QuantumObject(ϕ, type = X(), dims = prob.dimensions), sol.u)
ρt = map-> QuantumObject(ϕ, type = prob.states_type, dims = prob.dimensions), sol.u)

end

kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility

return TimeEvolutionSol(
times,
prob.times,
sol.t,
ρt,
_get_expvals(sol, SaveFuncMESolve),
Expand Down Expand Up @@ -86,8 +86,8 @@ function mesolveProblem(
progress_bar::Union{Val,Bool} = Val(true),
inplace::Union{Val,Bool} = Val(true),
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
(isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolveProblem(
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
(isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolveProblem(
H,
ψ0,
tlist;
Expand All @@ -107,11 +107,27 @@ function mesolveProblem(
check_dimensions(L_evo, ψ0)

T = Base.promote_eltype(L_evo, ψ0)
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
to_dense(_complex_float_type(T), copy(ψ0.data))
# ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
# to_dense(_complex_float_type(T), copy(ψ0.data))
# else
# to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
# end
if isoper(ψ0)
ρ0 = to_dense(_complex_float_type(T), mat2vec(ψ0.data))
state_type = Operator()
elseif isoperket(ψ0)
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
state_type = OperatorKet()
elseif isket(ψ0)
ρ0 = to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
state_type = Operator()
elseif issuper(ψ0)
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
state_type = SuperOperator()
else
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
throw(ArgumentError("Unsupported state type for ψ0 in mesolveProblem."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is redundant cause there is no method for other state type (as you defined above).

end

L = cache_operator(L_evo.data, ρ0)

kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
Expand All @@ -122,7 +138,7 @@ function mesolveProblem(

prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs4...)

return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),))
return TimeEvolutionProblem(prob, tlist, state_type, L_evo.dimensions)#, (isoperket = Val(isoperket(ψ0)),))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can remove the comment directly

end

@doc raw"""
Expand Down Expand Up @@ -188,8 +204,8 @@ function mesolve(
progress_bar::Union{Val,Bool} = Val(true),
inplace::Union{Val,Bool} = Val(true),
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
(isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolve(
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
(isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolve(
H,
ψ0,
tlist;
Expand Down Expand Up @@ -230,7 +246,7 @@ end
function mesolve(prob::TimeEvolutionProblem, alg::AbstractODEAlgorithm = DP5(); kwargs...)
sol = solve(prob.prob, alg; kwargs...)

return _gen_mesolve_solution(sol, prob.times, prob.dimensions, prob.kwargs.isoperket)
return _gen_mesolve_solution(sol, prob)#, prob.kwargs.isoperket)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can remove the comment directly

end

@doc raw"""
Expand Down Expand Up @@ -298,8 +314,8 @@ function mesolve_map(
params::Union{NullParameters,Tuple} = NullParameters(),
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
(isoper(H) && all(isket, ψ0) && isnothing(c_ops)) && return sesolve_map(
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
(isoper(H) && (all(isket, ψ0) || all(isoper, ψ0)) && isnothing(c_ops)) && return sesolve_map(
H,
ψ0,
tlist;
Expand All @@ -315,10 +331,16 @@ function mesolve_map(
# Convert to appropriate format based on state type
ψ0_iter = map(ψ0) do state
T = _complex_float_type(eltype(state))
if isoperket(state)
to_dense(T, copy(state.data))
if isoper(state)
to_dense(_complex_float_type(T), mat2vec(state.data))
elseif isoperket(state)
to_dense(_complex_float_type(T), copy(state.data))
elseif isket(state)
to_dense(_complex_float_type(T), mat2vec(ket2dm(state).data))
elseif issuper(state)
to_dense(_complex_float_type(T), copy(state.data))
else
to_dense(T, mat2vec(ket2dm(state).data))
throw(ArgumentError("Unsupported state type for ψ0 in mesolveProblem."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, this is redundant

end
end
if params isa NullParameters
Expand Down Expand Up @@ -347,7 +369,7 @@ mesolve_map(
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} =
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)

# this method is for advanced usage
Expand All @@ -357,14 +379,14 @@ mesolve_map(
#
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
function mesolve_map(
prob::TimeEvolutionProblem{<:ODEProblem},
prob::TimeEvolutionProblem{StateOpType,<:ODEProblem},
iter::AbstractArray,
alg::AbstractODEAlgorithm = DP5(),
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
prob_func::Union{Function,Nothing} = nothing,
output_func::Union{Tuple,Nothing} = nothing,
progress_bar::Union{Val,Bool} = Val(true),
)
) where {StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
# generate ensemble problem
ntraj = length(iter)
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
Expand All @@ -380,14 +402,14 @@ function mesolve_map(
ens_prob = TimeEvolutionProblem(
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
prob.times,
StateOpType(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
StateOpType(),
prob.state_type,

prob.dimensions,
(progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket),
(progr = _output_func[2], channel = _output_func[3]),
)

sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)

# handle solution and make it become an Array of TimeEvolutionSol
sol_vec =
[_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable
sol_vec = [_gen_mesolve_solution(sol[:, i], prob) for i in eachindex(sol)] # map is type unstable
return reshape(sol_vec, size(iter))
end
Loading