Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ext/QuantumToolboxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module QuantumToolboxCUDAExt

using QuantumToolbox
import QuantumToolbox: _convert_u0
import CUDA: cu, CuArray
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR
import SparseArrays: SparseVector, SparseMatrixCSC
Expand Down Expand Up @@ -89,4 +90,7 @@ _change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32

# make sure u0 in time evolution is dense vector and has complex element type
_convert_u0(u0::Union{CuArray{T},CuSparseVector{T}}) where {T<:Number} = convert(CuArray{complex(T)}, u0)

end
9 changes: 6 additions & 3 deletions src/steadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,14 @@ function steadystate(
(H.dims != ψ0.dims) && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))

N = prod(H.dims)
u0 = mat2vec(ket2dm(ψ0).data)
u0 = _convert_u0(mat2vec(ket2dm(ψ0).data))

Ftype = real(eltype(u0))
Tspan = (convert(Ftype, 0), convert(Ftype, tspan)) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

L = MatrixOperator(liouvillian(H, c_ops).data)

prob = ODEProblem{true}(L, u0, (0.0, tspan))
prob = ODEProblem{true}(L, u0, Tspan)
sol = solve(
prob,
solver.alg;
Expand All @@ -109,7 +113,6 @@ function steadystate(
)

ρss = reshape(sol.u[end], N, N)
ρss = (ρss + ρss') / 2 # Hermitianize
return QuantumObject(ρss, Operator, H.dims)
end

Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function mcsolveProblem(
c_ops isa Nothing &&
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = convert(Vector{real(eltype(ψ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2

Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ function mesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
ρ0 = _convert_u0(mat2vec(ket2dm(ψ0).data))

ρ0 = mat2vec(ket2dm(ψ0).data)
t_l = convert(Vector{real(eltype(ρ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

L = liouvillian(H, c_ops).data
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ function sesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
ϕ0 = _convert_u0(get_data(ψ0))

ϕ0 = get_data(ψ0)
t_l = convert(Vector{real(eltype(ϕ0))}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
Expand Down
3 changes: 3 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
join(arg, ", ") *
")` instead of `$argname = $arg`." maxlog = 1

# make sure u0 in time evolution is dense vector and has complex element type
_convert_u0(u0::AbstractVector{T}) where {T<:Number} = convert(Vector{complex(T)}, u0)