|
1 | 1 | module QuantumToolboxCUDAExt |
2 | 2 |
|
3 | 3 | using QuantumToolbox |
| 4 | +using QuantumToolbox: makeVal, getVal |
4 | 5 | import CUDA: cu, CuArray |
5 | 6 | import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR |
6 | 7 | import SparseArrays: SparseVector, SparseMatrixCSC |
@@ -70,19 +71,26 @@ Return a new [`QuantumObject`](@ref) where `A.data` is in the type of `CUDA` arr |
70 | 71 | - `A::QuantumObject`: The [`QuantumObject`](@ref) |
71 | 72 | - `word_size::Int`: The word size of the element type of `A`, can be either `32` or `64`. Default to `64`. |
72 | 73 | """ |
73 | | -cu(A::QuantumObject; word_size::Int = 64) = |
74 | | - ((word_size == 64) || (word_size == 32)) ? cu(A, Val(word_size)) : |
75 | | - throw(DomainError(word_size, "The word size should be 32 or 64.")) |
76 | | -cu(A::QuantumObject, word_size::TW) where {TW<:Union{Val{32},Val{64}}} = |
77 | | - CuArray{_change_eltype(eltype(A), word_size)}(A) |
78 | | -cu( |
| 74 | +function cu(A::QuantumObject; word_size::Union{Val,Int} = Val(64)) |
| 75 | + _word_size = getVal(makeVal(word_size)) |
| 76 | + |
| 77 | + ((_word_size == 64) || (_word_size == 32)) || throw(DomainError(_word_size, "The word size should be 32 or 64.")) |
| 78 | + |
| 79 | + return cu(A, makeVal(word_size)) |
| 80 | +end |
| 81 | +cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = CuArray{_change_eltype(eltype(A), word_size)}(A) |
| 82 | +function cu( |
79 | 83 | A::QuantumObject{ObjType,DimsType,<:SparseVector}, |
80 | | - word_size::TW, |
81 | | -) where {ObjType,DimsType,TW<:Union{Val{32},Val{64}}} = CuSparseVector{_change_eltype(eltype(A), word_size)}(A) |
82 | | -cu( |
| 84 | + word_size::Union{Val{32},Val{64}}, |
| 85 | +) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions} |
| 86 | + return CuSparseVector{_change_eltype(eltype(A), word_size)}(A) |
| 87 | +end |
| 88 | +function cu( |
83 | 89 | A::QuantumObject{ObjType,DimsType,<:SparseMatrixCSC}, |
84 | | - word_size::TW, |
85 | | -) where {ObjType,DimsType,TW<:Union{Val{32},Val{64}}} = CuSparseMatrixCSC{_change_eltype(eltype(A), word_size)}(A) |
| 90 | + word_size::Union{Val{32},Val{64}}, |
| 91 | +) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions} |
| 92 | + return CuSparseMatrixCSC{_change_eltype(eltype(A), word_size)}(A) |
| 93 | +end |
86 | 94 |
|
87 | 95 | _change_eltype(::Type{T}, ::Val{64}) where {T<:Int} = Int64 |
88 | 96 | _change_eltype(::Type{T}, ::Val{32}) where {T<:Int} = Int32 |
|
0 commit comments