Skip to content

Commit a06fbda

Browse files
committed
fixup
1 parent c0bc0e7 commit a06fbda

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ module QuantumToolboxCUDAExt
22

33
using QuantumToolbox
44
using QuantumToolbox: makeVal, getVal
5-
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
6-
import CUDA: cu, CuArray, allowscalar
5+
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize, _safe_setindex!
6+
import CUDA: cu, CuArray, allowscalar, @allowscalar, has_cuda
77
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
88
import SparseArrays: SparseVector, SparseMatrixCSC, sparse
99
import CUDA.Adapt: adapt
1010

11+
export _safe_setindex!
12+
1113
allowscalar(false)
1214

1315
@doc raw"""
@@ -108,10 +110,17 @@ function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, rows, cols, vals,
108110
cpu_sparse = sparse(rows, cols, vals, m, n)
109111
return CuSparseMatrixCSC(cpu_sparse)
110112
end
111-
112113
function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, rows, cols, vals, m::Int64, n::Int64)
113114
cpu_sparse = sparse(rows, cols, vals, m, n)
114115
return CuSparseMatrixCSR(cpu_sparse)
115116
end
116117

118+
function _safe_setindex!(A, val, idx)
119+
if has_cuda() && isa(A, CuArray)
120+
@allowscalar A[idx] = val
121+
else
122+
A[idx] = val
123+
end
124+
end
125+
117126
end

src/steadystate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ function _steadystate_fourier(
417417

418418
v0 = _dense_similar(L_0_mat, n_fourier * N)
419419
fill!(v0, 0)
420-
@allowscalar v0[n_max*N+1] = weight
420+
_safe_setindex!(v0, weight, n_max * N + 1)
421421

422422
if !isnothing(solver.Pl)
423423
kwargs = merge((; kwargs...), (Pl = solver.Pl(M),))

src/utilities.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ _dense_similar(A::AbstractSparseMatrix, args...) = similar(nonzeros(A), args...)
135135

136136
_sparse_similar(A::AbstractArray, args...) = sparse(args...)
137137
_sparse_similar(A::AbstractArray, m::Int, n::Int) = spzeros(eltype(A), m, n)
138+
function _safe_setindex! end
138139

139140
_Ginibre_ensemble(n::Int, rank::Int = n) = randn(ComplexF64, n, rank) / sqrt(n)
140141

0 commit comments

Comments
 (0)