diff --git a/.gitignore b/.gitignore index b07886a94..27bcfc910 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ Manifest.toml benchmarks/benchmarks_output.json .ipynb_checkpoints -*.ipynb \ No newline at end of file +.devcontainer/* +*.ipynb diff --git a/src/qobj/functions.jl b/src/qobj/functions.jl index 1d2cf5689..219bd5854 100644 --- a/src/qobj/functions.jl +++ b/src/qobj/functions.jl @@ -119,10 +119,12 @@ Converts a sparse QuantumObject to a dense QuantumObject. to_dense(A::QuantumObject) = QuantumObject(to_dense(A.data), A.type, A.dimensions) to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A) to_dense(A::MT) where {MT<:AbstractArray} = A +to_dense(A::Diagonal) = diagm(A.diag) to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A) to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A) to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A +to_dense(::Type{T}, A::Diagonal{T}) where {T<:Number} = diagm(A.diag) function to_dense(::Type{M}) where {M<:Union{Diagonal,SparseMatrixCSC}} T = M diff --git a/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl b/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl index e7c9837cc..499140fff 100644 --- a/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl +++ b/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl @@ -18,7 +18,7 @@ struct LindbladJump{ T2, RNGType<:AbstractRNG, RandT, - CT<:AbstractVector, + CT<:AbstractArray, WT<:AbstractVector, JTT<:AbstractVector, JWT<:AbstractVector, diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index 8630810c7..026d0107e 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -20,9 +20,9 @@ function _mcsolve_output_func(sol, i) return (sol, false) end -function _normalize_state!(u, dims, normalize_states) +function _normalize_state!(u, dims, normalize_states, type) getVal(normalize_states) && normalize!(u) - return QuantumObject(u, Ket(), dims) + return QuantumObject(u, type, dims) end function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops) @@ -110,7 +110,7 @@ If the environmental measurements register a quantum jump, the wave function und """ function mcsolveProblem( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, @@ -118,7 +118,7 @@ function mcsolveProblem( rng::AbstractRNG = default_rng(), jump_callback::TJC = ContinuousLindbladJumpCallback(), kwargs..., -) where {TJC<:LindbladJumpCallbackType} +) where {TJC<:LindbladJumpCallbackType,ST<:Union{Ket,Operator}} haskey(kwargs, :save_idxs) && throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox.")) @@ -221,7 +221,7 @@ If the environmental measurements register a quantum jump, the wave function und """ function mcsolveEnsembleProblem( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, @@ -234,7 +234,7 @@ function mcsolveEnsembleProblem( prob_func::Union{Function,Nothing} = nothing, output_func::Union{Tuple,Nothing} = nothing, kwargs..., -) where {TJC<:LindbladJumpCallbackType} +) where {TJC<:LindbladJumpCallbackType,ST<:Union{Ket,Operator}} _prob_func = isnothing(prob_func) ? _ensemble_dispatch_prob_func(rng, ntraj, tlist, _mcsolve_prob_func) : prob_func _output_func = output_func isa Nothing ? @@ -261,6 +261,7 @@ function mcsolveEnsembleProblem( ensemble_prob = TimeEvolutionProblem( EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false), prob_mc.times, + ST(), prob_mc.dimensions, (progr = _output_func[2], channel = _output_func[3]), ) @@ -358,7 +359,7 @@ If the environmental measurements register a quantum jump, the wave function und """ function mcsolve( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; alg::AbstractODEAlgorithm = DP5(), @@ -374,7 +375,7 @@ function mcsolve( keep_runs_results::Union{Val,Bool} = Val(false), normalize_states::Union{Val,Bool} = Val(true), kwargs..., -) where {TJC<:LindbladJumpCallbackType} +) where {TJC<:LindbladJumpCallbackType} where {ST<:Union{Ket,Operator}} ens_prob_mc = mcsolveEnsembleProblem( H, ψ0, @@ -414,8 +415,11 @@ function mcsolve( _expvals_sol_1 isa Nothing ? nothing : map(i -> _get_expvals(sol[:, i], SaveFuncMCSolve), eachindex(sol)) expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP - # stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject} - states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1) + + states_all = stack( + map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, [ens_prob_mc.states_type]), eachindex(sol)), # Unsure why ens_prob_mc.states_type needs to be in an array but the other two arguments don't! + dims = 1, + ) col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol)) col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol)) diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index b9aee6595..ebc4481cb 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -6,17 +6,17 @@ _mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvil _mesolve_make_L_QobjEvo(H::Nothing, c_ops::Nothing) = throw(ArgumentError("Both H and c_ops are Nothing. You are probably running the wrong function.")) -function _gen_mesolve_solution(sol, times, dimensions, isoperket::Val) - if getVal(isoperket) - ρt = map(ϕ -> QuantumObject(ϕ, type = OperatorKet(), dims = dimensions), sol.u) +function _gen_mesolve_solution(sol, prob::TimeEvolutionProblem{ST}) where {ST<:Union{Operator,OperatorKet,SuperOperator}} + if prob.states_type == Operator + ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = prob.states_type, dims = prob.dimensions), sol.u) else - ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = dimensions), sol.u) + ρt = map(ϕ -> QuantumObject(ϕ, type = prob.states_type, dims = prob.dimensions), sol.u) end kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility return TimeEvolutionSol( - times, + prob.times, sol.t, ρt, _get_expvals(sol, SaveFuncMESolve), @@ -86,8 +86,8 @@ function mesolveProblem( progress_bar::Union{Val,Bool} = Val(true), inplace::Union{Val,Bool} = Val(true), kwargs..., -) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} - (isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolveProblem( +) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} + (isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolveProblem( H, ψ0, tlist; @@ -107,11 +107,25 @@ function mesolveProblem( check_dimensions(L_evo, ψ0) T = Base.promote_eltype(L_evo, ψ0) - ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type - to_dense(_complex_float_type(T), copy(ψ0.data)) - else - to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data)) + # ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type + # to_dense(_complex_float_type(T), copy(ψ0.data)) + # else + # to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data)) + # end + if isoper(ψ0) + ρ0 = to_dense(_complex_float_type(T), mat2vec(ψ0.data)) + state_type = Operator() + elseif isoperket(ψ0) + ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data)) + state_type = OperatorKet() + elseif isket(ψ0) + ρ0 = to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data)) + state_type = Operator() + elseif issuper(ψ0) + ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data)) + state_type = SuperOperator() end + L = cache_operator(L_evo.data, ρ0) kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...) @@ -122,7 +136,7 @@ function mesolveProblem( prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs4...) - return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),)) + return TimeEvolutionProblem(prob, tlist, state_type, L_evo.dimensions) end @doc raw""" @@ -188,8 +202,8 @@ function mesolve( progress_bar::Union{Val,Bool} = Val(true), inplace::Union{Val,Bool} = Val(true), kwargs..., -) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} - (isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolve( +) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} + (isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolve( H, ψ0, tlist; @@ -230,7 +244,7 @@ end function mesolve(prob::TimeEvolutionProblem, alg::AbstractODEAlgorithm = DP5(); kwargs...) sol = solve(prob.prob, alg; kwargs...) - return _gen_mesolve_solution(sol, prob.times, prob.dimensions, prob.kwargs.isoperket) + return _gen_mesolve_solution(sol, prob) end @doc raw""" @@ -298,8 +312,8 @@ function mesolve_map( params::Union{NullParameters,Tuple} = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), kwargs..., -) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} - (isoper(H) && all(isket, ψ0) && isnothing(c_ops)) && return sesolve_map( +) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} + (isoper(H) && (all(isket, ψ0) || all(isoper, ψ0)) && isnothing(c_ops)) && return sesolve_map( H, ψ0, tlist; @@ -315,10 +329,14 @@ function mesolve_map( # Convert to appropriate format based on state type ψ0_iter = map(ψ0) do state T = _complex_float_type(eltype(state)) - if isoperket(state) - to_dense(T, copy(state.data)) - else - to_dense(T, mat2vec(ket2dm(state).data)) + if isoper(state) + to_dense(_complex_float_type(T), mat2vec(state.data)) + elseif isoperket(state) + to_dense(_complex_float_type(T), copy(state.data)) + elseif isket(state) + to_dense(_complex_float_type(T), mat2vec(ket2dm(state).data)) + elseif issuper(state) + to_dense(_complex_float_type(T), copy(state.data)) end end if params isa NullParameters @@ -347,7 +365,7 @@ mesolve_map( tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; kwargs..., -) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} = +) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} = mesolve_map(H, [ψ0], tlist, c_ops; kwargs...) # this method is for advanced usage @@ -357,14 +375,14 @@ mesolve_map( # # Return: An array of TimeEvolutionSol objects with the size same as the given iter. function mesolve_map( - prob::TimeEvolutionProblem{<:ODEProblem}, + prob::TimeEvolutionProblem{StateOpType, <:AbstractDimensions, <:ODEProblem}, iter::AbstractArray, alg::AbstractODEAlgorithm = DP5(), ensemblealg::EnsembleAlgorithm = EnsembleThreads(); prob_func::Union{Function,Nothing} = nothing, output_func::Union{Tuple,Nothing} = nothing, progress_bar::Union{Val,Bool} = Val(true), -) +) where {StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} # generate ensemble problem ntraj = length(iter) _prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func @@ -380,14 +398,14 @@ function mesolve_map( ens_prob = TimeEvolutionProblem( EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false), prob.times, + StateOpType(), prob.dimensions, - (progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket), + (progr = _output_func[2], channel = _output_func[3]), ) sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj) # handle solution and make it become an Array of TimeEvolutionSol - sol_vec = - [_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable + sol_vec = [_gen_mesolve_solution(sol[:, i], prob) for i in eachindex(sol)] # map is type unstable return reshape(sol_vec, size(iter)) end diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index df0b2cd93..c807e774c 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -2,13 +2,13 @@ export sesolveProblem, sesolve, sesolve_map _sesolve_make_U_QobjEvo(H) = -1im * QuantumObjectEvolution(H, type = Operator()) -function _gen_sesolve_solution(sol, times, dimensions) - ψt = map(ϕ -> QuantumObject(ϕ, type = Ket(), dims = dimensions), sol.u) +function _gen_sesolve_solution(sol, prob::TimeEvolutionProblem{ST}) where {ST<:Union{Ket,Operator}} + ψt = map(ϕ -> QuantumObject(ϕ, type = prob.states_type, dims = prob.dimensions), sol.u) kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility return TimeEvolutionSol( - times, + prob.times, sol.t, ψt, _get_expvals(sol, SaveFuncSESolve), @@ -61,14 +61,14 @@ Generate the ODEProblem for the Schrödinger time evolution of a quantum system: """ function sesolveProblem( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), inplace::Union{Val,Bool} = Val(true), kwargs..., -) +) where {ST<:Union{Ket,Operator}} haskey(kwargs, :save_idxs) && throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox.")) @@ -90,7 +90,7 @@ function sesolveProblem( prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs4...) - return TimeEvolutionProblem(prob, tlist, H_evo.dimensions) + return TimeEvolutionProblem(prob, tlist, ST(), H_evo.dimensions) end @doc raw""" @@ -138,7 +138,7 @@ Time evolution of a closed quantum system using the Schrödinger equation: """ function sesolve( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector; alg::AbstractODEAlgorithm = Vern7(lazy = false), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, @@ -146,7 +146,7 @@ function sesolve( progress_bar::Union{Val,Bool} = Val(true), inplace::Union{Val,Bool} = Val(true), kwargs..., -) +) where {ST<:Union{Ket,Operator}} # Move sensealg argument to solve for Enzyme.jl support. # TODO: Remove it when https://github.com/SciML/SciMLSensitivity.jl/issues/1225 is fixed. @@ -175,7 +175,7 @@ end function sesolve(prob::TimeEvolutionProblem, alg::AbstractODEAlgorithm = Vern7(lazy = false); kwargs...) sol = solve(prob.prob, alg; kwargs...) - return _gen_sesolve_solution(sol, prob.times, prob.dimensions) + return _gen_sesolve_solution(sol, prob) end @doc raw""" @@ -225,7 +225,7 @@ for each combination in the ensemble. """ function sesolve_map( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::AbstractVector{<:QuantumObject{Ket}}, + ψ0::AbstractVector{<:QuantumObject{ST}}, tlist::AbstractVector; alg::AbstractODEAlgorithm = Vern7(lazy = false), ensemblealg::EnsembleAlgorithm = EnsembleThreads(), @@ -233,8 +233,11 @@ function sesolve_map( params::Union{NullParameters,Tuple} = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), kwargs..., -) +) where {ST<:Union{Ket,Operator}} # mapping initial states and parameters + + ψ0 = map(to_dense, ψ0) # Convert all initial states to dense vectors + ψ0_iter = map(get_data, ψ0) if params isa NullParameters iter = collect(Iterators.product(ψ0_iter, [params])) |> vec # convert nx1 Matrix into Vector @@ -255,8 +258,12 @@ function sesolve_map( return sesolve_map(prob, iter, alg, ensemblealg; progress_bar = progress_bar) end -sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) = - sesolve_map(H, [ψ0], tlist; kwargs...) +sesolve_map( + H::Union{AbstractQuantumObject{Operator},Tuple}, + ψ0::QuantumObject{ST}, + tlist::AbstractVector; + kwargs..., +) where {ST<:Union{Ket,Operator}} = sesolve_map(H, [ψ0], tlist; kwargs...) # this method is for advanced usage # User can define their own iterator structure, prob_func and output_func @@ -265,14 +272,14 @@ sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{ # # Return: An array of TimeEvolutionSol objects with the size same as the given iter. function sesolve_map( - prob::TimeEvolutionProblem{<:ODEProblem}, + prob::TimeEvolutionProblem{ST, <:AbstractDimensions, <:ODEProblem}, iter::AbstractArray, alg::AbstractODEAlgorithm = Vern7(lazy = false), ensemblealg::EnsembleAlgorithm = EnsembleThreads(); prob_func::Union{Function,Nothing} = nothing, output_func::Union{Tuple,Nothing} = nothing, progress_bar::Union{Val,Bool} = Val(true), -) +) where {ST<:Union{Ket,Operator}} # generate ensemble problem ntraj = length(iter) _prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func @@ -288,6 +295,7 @@ function sesolve_map( ens_prob = TimeEvolutionProblem( EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false), prob.times, + prob.states_type, prob.dimensions, (progr = _output_func[2], channel = _output_func[3]), ) @@ -295,6 +303,6 @@ function sesolve_map( sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj) # handle solution and make it become an Array of TimeEvolutionSol - sol_vec = [_gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions) for i in eachindex(sol)] # map is type unstable + sol_vec = [_gen_sesolve_solution(sol[:, i], prob) for i in eachindex(sol)] # map is type unstable return reshape(sol_vec, size(iter)) end diff --git a/src/time_evolution/smesolve.jl b/src/time_evolution/smesolve.jl index 1560de45f..207ff9ba2 100644 --- a/src/time_evolution/smesolve.jl +++ b/src/time_evolution/smesolve.jl @@ -1,7 +1,14 @@ export smesolveProblem, smesolveEnsembleProblem, smesolve -_smesolve_generate_state(u, dims, isoperket::Val{false}) = QuantumObject(vec2mat(u), type = Operator(), dims = dims) -_smesolve_generate_state(u, dims, isoperket::Val{true}) = QuantumObject(u, type = OperatorKet(), dims = dims) +#_smesolve_generate_state(u, dims, isoperket::Val{false}) = QuantumObject(vec2mat(u), type = Operator(), dims = dims) +#_smesolve_generate_state(u, dims, isoperket::Val{true}) = QuantumObject(u, type = OperatorKet(), dims = dims) +function _smesolve_generate_state(u, dims, type) + if type == OperatorKet + return QuantumObject(u, type = type, dims = dims) + else + return QuantumObject(vec2mat(u), type = Operator(), dims = dims) + end +end function _smesolve_update_coeff(u, p, t, op_vec) return 2 * real(dot(op_vec, u)) #this is Tr[Sn * ρ + ρ * Sn'] @@ -146,7 +153,7 @@ function smesolveProblem( kwargs4..., ) - return TimeEvolutionProblem(prob, tlist, dims, (isoperket = Val(isoperket(ψ0)),)) + return TimeEvolutionProblem(prob, tlist,StateOpType(), dims, ()) end @doc raw""" @@ -274,8 +281,9 @@ function smesolveEnsembleProblem( ensemble_prob = TimeEvolutionProblem( EnsembleProblem(prob_sme, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true), prob_sme.times, + StateOpType(), prob_sme.dimensions, - merge(prob_sme.kwargs, (progr = _output_func[2], channel = _output_func[3])), + (progr = _output_func[2], channel = _output_func[3]), ) return ensemble_prob @@ -422,7 +430,7 @@ function smesolve( # stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject} states_all = stack( - map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), ens_prob.kwargs.isoperket), eachindex(sol)), + map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims), [ens_prob.states_type]), eachindex(sol)), dims = 1, ) diff --git a/src/time_evolution/ssesolve.jl b/src/time_evolution/ssesolve.jl index df3372858..f56396b5a 100644 --- a/src/time_evolution/ssesolve.jl +++ b/src/time_evolution/ssesolve.jl @@ -76,7 +76,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is th """ function ssesolveProblem( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector, sc_ops::Union{Nothing,AbstractVector,Tuple,AbstractQuantumObject} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, @@ -85,7 +85,7 @@ function ssesolveProblem( progress_bar::Union{Val,Bool} = Val(true), store_measurement::Union{Val,Bool} = Val(false), kwargs..., -) +) where {ST<:Union{Ket,Operator}} haskey(kwargs, :save_idxs) && throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox.")) @@ -142,7 +142,7 @@ function ssesolveProblem( kwargs4..., ) - return TimeEvolutionProblem(prob, tlist, dims) + return TimeEvolutionProblem(prob, tlist, ST(), dims) end @doc raw""" @@ -218,7 +218,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is t """ function ssesolveEnsembleProblem( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector, sc_ops::Union{Nothing,AbstractVector,Tuple,AbstractQuantumObject} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, @@ -231,7 +231,7 @@ function ssesolveEnsembleProblem( progress_bar::Union{Val,Bool} = Val(true), store_measurement::Union{Val,Bool} = Val(false), kwargs..., -) +) where {ST<:Union{Ket,Operator}} _prob_func = isnothing(prob_func) ? _ensemble_dispatch_prob_func( @@ -268,6 +268,7 @@ function ssesolveEnsembleProblem( ensemble_prob = TimeEvolutionProblem( EnsembleProblem(prob_sme, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true), prob_sme.times, + ST(), prob_sme.dimensions, (progr = _output_func[2], channel = _output_func[3]), ) @@ -355,7 +356,7 @@ Above, ``\hat{S}_n`` are the stochastic collapse operators and ``dW_n(t)`` is th """ function ssesolve( H::Union{AbstractQuantumObject{Operator},Tuple}, - ψ0::QuantumObject{Ket}, + ψ0::QuantumObject{ST}, tlist::AbstractVector, sc_ops::Union{Nothing,AbstractVector,Tuple,AbstractQuantumObject} = nothing; alg::Union{Nothing,AbstractSDEAlgorithm} = nothing, @@ -370,7 +371,7 @@ function ssesolve( keep_runs_results::Union{Val,Bool} = Val(false), store_measurement::Union{Val,Bool} = Val(false), kwargs..., -) +) where {ST<:Union{Ket,Operator}} ens_prob = ssesolveEnsembleProblem( H, ψ0, @@ -417,7 +418,7 @@ function ssesolve( expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP # stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject} - states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1) + states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, [ens_prob.states_type]), eachindex(sol)), dims = 1) _m_expvals = _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol)) diff --git a/src/time_evolution/time_evolution.jl b/src/time_evolution/time_evolution.jl index e7c3791d3..a855c34ac 100644 --- a/src/time_evolution/time_evolution.jl +++ b/src/time_evolution/time_evolution.jl @@ -19,15 +19,17 @@ A Julia constructor for handling the `ODEProblem` of the time evolution of quant - `prob::AbstractSciMLProblem`: The `ODEProblem` of the time evolution. - `times::AbstractVector`: The time list of the evolution. +- `states_type::QuantumObjectType`: The type of the quantum states during the evolution (e.g., `Ket`, `Operator`, `OperatorKet` or `SuperOperator`). - `dimensions::AbstractDimensions`: The dimensions of the Hilbert space. - `kwargs::KWT`: Generic keyword arguments. !!! note "`dims` property" For a given `prob::TimeEvolutionProblem`, `prob.dims` or `getproperty(prob, :dims)` returns its `dimensions` in the type of integer-vector. """ -struct TimeEvolutionProblem{PT<:AbstractSciMLProblem,TT<:AbstractVector,DT<:AbstractDimensions,KWT} +struct TimeEvolutionProblem{ST<:QuantumObjectType, DT<:AbstractDimensions,PT<:AbstractSciMLProblem,TT<:AbstractVector,KWT} prob::PT times::TT + states_type::ST dimensions::DT kwargs::KWT end @@ -41,7 +43,7 @@ function Base.getproperty(prob::TimeEvolutionProblem, key::Symbol) end end -TimeEvolutionProblem(prob, times, dims) = TimeEvolutionProblem(prob, times, dims, nothing) +TimeEvolutionProblem(prob, times, states_type, dims) = TimeEvolutionProblem(prob, times, states_type, dims, nothing) @doc raw""" struct TimeEvolutionSol