Skip to content

Commit 3219cb8

Browse files
Make cuda conversion more general
1 parent acc1ddd commit 3219cb8

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
66
import CUDA: cu, CuArray, allowscalar
77
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
88
import SparseArrays: SparseVector, SparseMatrixCSC
9+
import CUDA.Adapt: adapt
910

1011
allowscalar(false)
1112

@@ -81,7 +82,7 @@ function cu(A::QuantumObject; word_size::Union{Val,Int} = Val(64))
8182

8283
return cu(A, makeVal(word_size))
8384
end
84-
cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = CuArray{_convert_eltype_wordsize(eltype(A), word_size)}(A)
85+
cu(A::QuantumObject, word_size::Union{Val{32},Val{64}}) = QuantumObject(adapt(CuArray{_convert_eltype_wordsize(eltype(A), word_size)}, A.data), A.type, A.dimensions)
8586
function cu(
8687
A::QuantumObject{ObjType,DimsType,<:SparseVector},
8788
word_size::Union{Val{32},Val{64}},

test/ext-test/gpu/cuda_ext.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
@test typeof(CuSparseMatrixCSR(Xsc).data) == CuSparseMatrixCSR{ComplexF64,Int32}
6767
@test typeof(CuSparseMatrixCSR{ComplexF32}(Xsc).data) == CuSparseMatrixCSR{ComplexF32,Int32}
6868

69+
# type conversion of CUDA Diagonal arrays
70+
@test cu(qeye(10), word_size=Val(32)).data isa Diagonal{ComplexF32, <:CuVector{ComplexF32}}
71+
@test cu(qeye(10), word_size=Val(64)).data isa Diagonal{ComplexF64, <:CuVector{ComplexF64}}
72+
6973
# Sparse To Dense
7074
# @test to_dense(cu(ψsi; word_size = 64)).data isa CuVector{Int64} # TODO: Fix this in CUDA.jl
7175
@test to_dense(cu(ψsf; word_size = 64)).data isa CuVector{Float64}

0 commit comments

Comments
 (0)