@@ -3,11 +3,13 @@ module QuantumToolboxCUDAExt
33using QuantumToolbox
44using QuantumToolbox: makeVal, getVal
55import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
6- import CUDA: cu, CuArray, allowscalar
6+ import CUDA: cu, CuArray, allowscalar, @allowscalar , has_cuda
77import CUDA. CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
8- import SparseArrays: SparseVector, SparseMatrixCSC, sparse
8+ import SparseArrays: SparseVector, SparseMatrixCSC, sparse, spzeros
99import CUDA. Adapt: adapt
1010
11+ export _safe_setindex!
12+
1113allowscalar (false )
1214
1315@doc raw """
@@ -104,5 +106,10 @@ QuantumToolbox.to_dense(::Type{T}, A::AbstractCuSparseArray) where {T<:Number} =
104106
105107QuantumToolbox. _sparse_similar (A:: CuSparseMatrixCSC , args... ) = sparse (args... , fmt = :csc )
106108QuantumToolbox. _sparse_similar (A:: CuSparseMatrixCSR , args... ) = sparse (args... , fmt = :csr )
107-
109+ _sparse_similar (A:: CuSparseMatrixCSC , I:: AbstractVector , J:: AbstractVector , V:: AbstractVector , m:: Int , n:: Int ) =
110+ CuSparseMatrixCSC (sparse (I, J, V, m, n))
111+ _sparse_similar (A:: CuSparseMatrixCSC , m:: Int , n:: Int ) = CuSparseMatrixCSC (spzeros (eltype (A), m, n))
112+ _sparse_similar (A:: CuSparseMatrixCSR , I:: AbstractVector , J:: AbstractVector , V:: AbstractVector , m:: Int , n:: Int ) =
113+ CuSparseMatrixCSR (sparse (I, J, V, m, n))
114+ _sparse_similar (A:: CuSparseMatrixCSR , m:: Int , n:: Int ) = CuSparseMatrixCSR (spzeros (eltype (A), m, n))
108115end
0 commit comments