@@ -2,12 +2,14 @@ module QuantumToolboxCUDAExt
22
33using QuantumToolbox
44using QuantumToolbox: makeVal, getVal
5- import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
6- import CUDA: cu, CuArray, allowscalar
5+ import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize, _safe_setindex!
6+ import CUDA: cu, CuArray, allowscalar, @allowscalar , has_cuda
77import CUDA. CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
88import SparseArrays: SparseVector, SparseMatrixCSC, sparse
99import CUDA. Adapt: adapt
1010
11+ export _safe_setindex!
12+
1113allowscalar (false )
1214
1315@doc raw """
@@ -108,10 +110,17 @@ function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, rows, cols, vals,
108110 cpu_sparse = sparse (rows, cols, vals, m, n)
109111 return CuSparseMatrixCSC (cpu_sparse)
110112end
111-
112113function QuantumToolbox. _sparse_similar (A:: CuSparseMatrixCSR , rows, cols, vals, m:: Int64 , n:: Int64 )
113114 cpu_sparse = sparse (rows, cols, vals, m, n)
114115 return CuSparseMatrixCSR (cpu_sparse)
115116end
116117
118+ function _safe_setindex! (A, val, idx)
119+ if has_cuda () && isa (A, CuArray)
120+ @allowscalar A[idx] = val
121+ else
122+ A[idx] = val
123+ end
124+ end
125+
117126end
0 commit comments