Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions ext/QuantumToolboxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module QuantumToolboxCUDAExt

using QuantumToolbox
using QuantumToolbox: makeVal, getVal
import QuantumToolbox: _sparse_similar
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
import CUDA: cu, CuArray, allowscalar
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
import SparseArrays: SparseVector, SparseMatrixCSC
Expand Down Expand Up @@ -81,27 +81,20 @@ function cu(A::QuantumObject; word_size::Union{Val,Int} = Val(64))

return cu(A, makeVal(word_size))
end
cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = CuArray{_change_eltype(eltype(A), word_size)}(A)
cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = CuArray{_convert_eltype_wordsize(eltype(A), word_size)}(A)
function cu(
A::QuantumObject{ObjType,DimsType,<:SparseVector},
word_size::Union{Val{32},Val{64}},
) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions}
return CuSparseVector{_change_eltype(eltype(A), word_size)}(A)
return CuSparseVector{_convert_eltype_wordsize(eltype(A), word_size)}(A)
end
function cu(
A::QuantumObject{ObjType,DimsType,<:SparseMatrixCSC},
word_size::Union{Val{32},Val{64}},
) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions}
return CuSparseMatrixCSC{_change_eltype(eltype(A), word_size)}(A)
return CuSparseMatrixCSC{_convert_eltype_wordsize(eltype(A), word_size)}(A)
end

_change_eltype(::Type{T}, ::Val{64}) where {T<:Int} = Int64
_change_eltype(::Type{T}, ::Val{32}) where {T<:Int} = Int32
_change_eltype(::Type{T}, ::Val{64}) where {T<:AbstractFloat} = Float64
_change_eltype(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
_change_eltype(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
_change_eltype(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32

QuantumToolbox.to_dense(A::MT) where {MT<:AbstractCuSparseArray} = CuArray(A)

QuantumToolbox.to_dense(::Type{T1}, A::CuArray{T2}) where {T1<:Number,T2<:Number} = CuArray{T1}(A)
Expand Down
7 changes: 7 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,10 @@ _CType(::Type{Complex{Int32}}) = ComplexF32
_CType(::Type{Complex{Int64}}) = ComplexF64
_CType(::Type{Complex{Float32}}) = ComplexF32
_CType(::Type{Complex{Float64}}) = ComplexF64

_convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:Int} = Int64
_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:Int} = Int32
_convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:AbstractFloat} = Float64
_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32
Loading