Skip to content

Commit 8cb4e17

Browse files
committed
polish
1 parent 30ece37 commit 8cb4e17

File tree

3 files changed

+29
-44
lines changed

3 files changed

+29
-44
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module QuantumToolboxCUDAExt
22

33
using QuantumToolbox
44
using QuantumToolbox: makeVal, getVal
5-
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
5+
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize, allowed_setindex!
66
import CUDA: cu, CuArray, allowscalar, @allowscalar, has_cuda
77
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
88
import SparseArrays: SparseVector, SparseMatrixCSC, sparse, spzeros
@@ -104,10 +104,11 @@ QuantumToolbox.to_dense(::Type{T}, A::AbstractCuSparseArray) where {T<:Number} =
104104

105105
# Catch-all for CSC-format CUDA sparse matrices: `A` is used for dispatch only,
# and every remaining argument is forwarded to `sparse` with the CSC format
# requested explicitly.
function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, args...)
    return sparse(args...; fmt = :csc)
end
106106
# Catch-all for CSR-format CUDA sparse matrices: `A` is used for dispatch only,
# and every remaining argument is forwarded to `sparse` with the CSR format
# requested explicitly.
function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, args...)
    return sparse(args...; fmt = :csr)
end
107-
_sparse_similar(A::CuSparseMatrixCSC, I::AbstractVector, J::AbstractVector, V::AbstractVector, m::Int, n::Int) =
107+
# Coordinate-form constructor for the CSC case: assembles an `m`-by-`n` sparse
# matrix from the index vectors `I`, `J` and the values `V` via `sparse`, then
# passes the result to the `CuSparseMatrixCSC` constructor. `A` is not read in
# the body; it only selects this method through dispatch.
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, I::AbstractVector, J::AbstractVector, V::AbstractVector, m::Int, n::Int) =
108108
CuSparseMatrixCSC(sparse(I, J, V, m, n))
109-
_sparse_similar(A::CuSparseMatrixCSC, m::Int, n::Int) = CuSparseMatrixCSC(spzeros(eltype(A), m, n))
110-
_sparse_similar(A::CuSparseMatrixCSR, I::AbstractVector, J::AbstractVector, V::AbstractVector, m::Int, n::Int) =
109+
# Size-only constructor for the CSC case: builds an `m`-by-`n` sparse matrix of
# zeros whose element type is taken from `A`, then passes it to the
# `CuSparseMatrixCSC` constructor.
function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, m::Int, n::Int)
    zero_mat = spzeros(eltype(A), m, n)
    return CuSparseMatrixCSC(zero_mat)
end
110+
# Coordinate-form constructor for the CSR case: assembles an `m`-by-`n` sparse
# matrix from the index vectors `I`, `J` and the values `V` via `sparse`, then
# passes the result to the `CuSparseMatrixCSR` constructor. `A` is not read in
# the body; it only selects this method through dispatch.
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, I::AbstractVector, J::AbstractVector, V::AbstractVector, m::Int, n::Int) =
111111
CuSparseMatrixCSR(sparse(I, J, V, m, n))
112-
_sparse_similar(A::CuSparseMatrixCSR, m::Int, n::Int) = CuSparseMatrixCSR(spzeros(eltype(A), m, n))
112+
# Size-only constructor for the CSR case: builds an `m`-by-`n` sparse matrix of
# zeros whose element type is taken from `A`, then passes it to the
# `CuSparseMatrixCSR` constructor.
function QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, m::Int, n::Int)
    zero_mat = spzeros(eltype(A), m, n)
    return CuSparseMatrixCSR(zero_mat)
end
113+
# Element assignment for CUDA sparse arrays: performs `A[I...] = v` wrapped in
# `@allowscalar`, and evaluates to the result of that assignment expression.
# NOTE(review): this method covers `AbstractCuSparseArray` only; confirm that
# dense `CuArray` inputs are handled by another method if they can reach here.
function QuantumToolbox.allowed_setindex!(A::AbstractCuSparseArray, v, I...)
    return @allowscalar A[I...] = v
end
113114
end

src/steadystate.jl

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -369,63 +369,45 @@ function _steadystate_fourier(
369369
n_fourier = 2 * n_max + 1
370370
n_list = (-n_max):n_max
371371

372-
weight = 1
373-
rows = _dense_similar(L_0_mat, Ns)
374-
cols = _dense_similar(L_0_mat, Ns)
375-
vals = _dense_similar(L_0_mat, Ns)
376-
fill!(rows, 1)
377-
for j in 1:Ns
378-
cols[j] = Ns * (j - 1) + j
379-
end
380-
fill!(vals, weight)
372+
weight = one(T)
373+
rows = ones(Int, Ns)
374+
cols = [Ns * (j - 1) + j for j in 1:Ns]
375+
vals = fill(weight, Ns)
381376
Mn = _sparse_similar(L_0_mat, rows, cols, vals, N, N)
382377
L = L_0_mat + Mn
383378

384379
M = _sparse_similar(L_0_mat, n_fourier * N, n_fourier * N)
385-
386-
# Add superdiagonal blocks (L_m)
387380
for i in 1:(n_fourier-1)
388-
rows_block = _dense_similar(L_0_mat, N)
389-
cols_block = _dense_similar(L_0_mat, N)
390-
fill!(rows_block, i)
391-
fill!(cols_block, i+1)
392-
block = _sparse_similar(L_0_mat, rows_block, cols_block, ones(N), n_fourier, n_fourier)
393-
M += kron(block, L_m_mat)
381+
M_block = _sparse_similar(L_0_mat, N, N)
382+
copyto!(M_block, L_p_mat)
383+
M[(i*N+1):((i+1)*N), ((i-1)*N+1):(i*N)] = M_block
384+
M_block = _sparse_similar(L_0_mat, N, N)
385+
copyto!(M_block, L_m_mat)
386+
M[((i-1)*N+1):(i*N), (i*N+1):((i+1)*N)] = M_block
394387
end
395388

396-
# Add subdiagonal blocks (L_p)
397-
for i in 1:(n_fourier-1)
398-
rows_block = _dense_similar(L_0_mat, N)
399-
cols_block = _dense_similar(L_0_mat, N)
400-
fill!(rows_block, i+1)
401-
fill!(cols_block, i)
402-
block = _sparse_similar(L_0_mat, rows_block, cols_block, ones(N), n_fourier, n_fourier)
403-
M += kron(block, L_p_mat)
404-
end
405-
406-
# Add diagonal blocks (L - i*ωd*n)
407389
for i in 1:n_fourier
408390
n = n_list[i]
409-
block_diag = L - 1im * ωd * n * I
410-
rows_block = _dense_similar(L_0_mat, N)
411-
cols_block = _dense_similar(L_0_mat, N)
412-
fill!(rows_block, i)
413-
fill!(cols_block, i)
414-
block = _sparse_similar(L_0_mat, rows_block, cols_block, ones(N), n_fourier, n_fourier)
415-
M += kron(block, block_diag)
391+
M_block = _sparse_similar(L_0_mat, N, N)
392+
copyto!(M_block, L)
393+
M_block -= 1im * ωd * n * sparse(I, N, N)
394+
M[((i-1)*N+1):(i*N), ((i-1)*N+1):(i*N)] = M_block
416395
end
417396

418-
v0 = _dense_similar(L_0_mat, n_fourier * N)
419-
fill!(v0, 0)
420-
allowed_setindex!(v0, weight, n_max * N + 1)
397+
v0 = similar(L_0_mat, n_fourier * N)
398+
fill!(v0, zero(T))
399+
target_idx = n_max*N + 1
400+
QuantumToolbox.allowed_setindex!(v0, weight, target_idx)
421401

402+
(haskey(kwargs, :Pl) || haskey(kwargs, :Pr)) && error("The use of preconditioners must be defined in the solver.")
422403
if !isnothing(solver.Pl)
423404
kwargs = merge((; kwargs...), (Pl = solver.Pl(M),))
424405
elseif isa(M, SparseMatrixCSC)
425406
kwargs = merge((; kwargs...), (Pl = ilu(M, τ = 0.01),))
426407
end
427408
!isnothing(solver.Pr) && (kwargs = merge((; kwargs...), (Pr = solver.Pr(M),)))
428-
kwargs = merge((abstol = tol, reltol = tol), kwargs)
409+
!haskey(kwargs, :abstol) && (kwargs = merge((; kwargs...), (abstol = tol,)))
410+
!haskey(kwargs, :reltol) && (kwargs = merge((; kwargs...), (reltol = tol,)))
429411

430412
prob = LinearProblem(M, v0)
431413
ρtot = solve(prob, solver.alg; kwargs...).u

src/utilities.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,5 @@ _convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:AbstractFloat} = Float6
200200
_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
201201
_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
202202
_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32
203+
204+
# allowed_setindex!(A, v, I...)
#
# Generic function declared here with zero methods so that other modules can
# attach methods to it. The CUDA extension adds one for `AbstractCuSparseArray`
# that performs `A[I...] = v` under `@allowscalar`, and `_steadystate_fourier`
# calls it to set a single entry of its right-hand-side vector.
# NOTE(review): no method for ordinary (CPU) arrays is visible at this point;
# confirm one is defined elsewhere, otherwise such calls raise a `MethodError`.
function allowed_setindex! end

0 commit comments

Comments
 (0)