Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ext/QuantumToolboxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,10 @@
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32

sparse_to_dense(::Type{T}, A::CuArray{T}) where {T<:Number} = A
sparse_to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A)
sparse_to_dense(::Type{T}, A::CuSparseVector) where {T<:Number} = CuArray{T}(A)
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSC) where {T<:Number} = CuArray{T}(A)
sparse_to_dense(::Type{T}, A::CuSparseMatrixCSR) where {T<:Number} = CuArray{T}(A)

Check warning on line 96 in ext/QuantumToolboxCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuantumToolboxCUDAExt.jl#L92-L96

Added lines #L92 - L96 were not covered by tests

end
6 changes: 5 additions & 1 deletion src/qobj/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,16 @@
Converts a sparse QuantumObject to a dense QuantumObject.
"""
sparse_to_dense(A::QuantumObject{<:AbstractVecOrMat}) = QuantumObject(sparse_to_dense(A.data), A.type, A.dims)
sparse_to_dense(A::MT) where {MT<:AbstractSparseMatrix} = Array(A)
sparse_to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
for op in (:Transpose, :Adjoint)
@eval sparse_to_dense(A::$op{T,<:AbstractSparseMatrix}) where {T<:BlasFloat} = Array(A)
end
sparse_to_dense(A::MT) where {MT<:AbstractArray} = A

sparse_to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)

Check warning on line 117 in src/qobj/functions.jl

View check run for this annotation

Codecov / codecov/patch

src/qobj/functions.jl#L117

Added line #L117 was not covered by tests
sparse_to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
sparse_to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A

function sparse_to_dense(::Type{M}) where {M<:SparseMatrixCSC}
T = M
par = T.parameters
Expand Down
4 changes: 4 additions & 0 deletions src/qobj/quantum_object.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,7 @@ SparseArrays.SparseMatrixCSC(A::QuantumObject{<:AbstractMatrix}) =
QuantumObject(SparseMatrixCSC(A.data), A.type, A.dims)
SparseArrays.SparseMatrixCSC{T}(A::QuantumObject{<:SparseMatrixCSC}) where {T<:Number} =
QuantumObject(SparseMatrixCSC{T}(A.data), A.type, A.dims)

# functions for getting Float or Complex element type
_FType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _FType(T)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to put them together with the othe definitions in the utilities file.

_CType(::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = _CType(T)
7 changes: 4 additions & 3 deletions src/steadystate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ function steadystate(
(H.dims != ψ0.dims) && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))

N = prod(H.dims)
u0 = mat2vec(ket2dm(ψ0).data)
u0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data))

L = MatrixOperator(liouvillian(H, c_ops).data)

prob = ODEProblem{true}(L, u0, (0.0, tspan))
ftype = _FType(ψ0)
prob = ODEProblem{true}(L, u0, (ftype(0), ftype(tspan))) # Convert tspan to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl
sol = solve(
prob,
solver.alg;
Expand All @@ -109,7 +111,6 @@ function steadystate(
)

ρss = reshape(sol.u[end], N, N)
ρss = (ρss + ρss') / 2 # Hermitianize
return QuantumObject(ρss, Operator, H.dims)
end

Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function mcsolveProblem(
c_ops isa Nothing &&
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

H_eff = H - 1im * mapreduce(op -> op' * op, +, c_ops) / 2

Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ function mesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type

ρ0 = mat2vec(ket2dm(ψ0).data)
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

L = liouvillian(H, c_ops).data
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
Expand Down
4 changes: 2 additions & 2 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ function sesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl
ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type

ϕ0 = get_data(ψ0)
t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
Expand Down
16 changes: 16 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,19 @@
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
join(arg, ", ") *
")` instead of `$argname = $arg`." maxlog = 1

# functions for getting Float or Complex element type
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)
_FType(::Type{Int32}) = Float32

Check warning on line 71 in src/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/utilities.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
_FType(::Type{Int64}) = Float64
_FType(::Type{Float32}) = Float32
_FType(::Type{Float64}) = Float64
_FType(::Type{ComplexF32}) = Float32

Check warning on line 75 in src/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/utilities.jl#L73-L75

Added lines #L73 - L75 were not covered by tests
_FType(::Type{ComplexF64}) = Float64
_CType(::AbstractArray{T}) where {T<:Number} = _CType(T)
_CType(::Type{Int32}) = ComplexF32

Check warning on line 78 in src/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/utilities.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
_CType(::Type{Int64}) = ComplexF64
_CType(::Type{Float32}) = ComplexF32
_CType(::Type{Float64}) = ComplexF64
_CType(::Type{ComplexF32}) = ComplexF32

Check warning on line 82 in src/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/utilities.jl#L80-L82

Added lines #L80 - L82 were not covered by tests
_CType(::Type{ComplexF64}) = ComplexF64
8 changes: 7 additions & 1 deletion test/core-test/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
"reltol = $(sol.reltol)\n"

@testset "Type Inference sesolve" begin
@inferred sesolveProblem(H, psi0, t_l)
@inferred sesolveProblem(H, psi0, t_l, progress_bar = Val(false))
@inferred sesolveProblem(H, psi0, [0, 10], progress_bar = Val(false))
@inferred sesolveProblem(H, Qobj(zeros(Int64, N * 2); dims = (N, 2)), t_l, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
Expand Down Expand Up @@ -91,6 +93,8 @@

@testset "Type Inference mesolve" begin
@inferred mesolveProblem(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolveProblem(H, psi0, [0, 10], c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolveProblem(H, Qobj(zeros(Int64, N)), t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = Val(false))
@inferred mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false))
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
Expand All @@ -108,6 +112,8 @@
)
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
@inferred mcsolve(H, psi0, [0, 10], c_ops, n_traj = 500, progress_bar = Val(false))
@inferred mcsolve(H, Qobj(zeros(Int64, N)), t_l, c_ops, n_traj = 500, progress_bar = Val(false))
end
end

Expand Down