Skip to content

Commit e901cf9

Browse files
committed
enhancements
1 parent 8cb4e17 commit e901cf9

File tree

3 files changed

+95
-59
lines changed

3 files changed

+95
-59
lines changed

ext/QuantumToolboxCUDAExt.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ module QuantumToolboxCUDAExt
22

33
using QuantumToolbox
44
using QuantumToolbox: makeVal, getVal
5-
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize, allowed_setindex!
6-
import CUDA: cu, CuArray, allowscalar, @allowscalar, has_cuda
5+
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
6+
import CUDA: cu, CuArray, allowscalar
77
import CUDA.CUSPARSE: CuSparseVector, CuSparseMatrixCSC, CuSparseMatrixCSR, AbstractCuSparseArray
88
import SparseArrays: SparseVector, SparseMatrixCSC, sparse, spzeros
99
import CUDA.Adapt: adapt
@@ -104,11 +104,22 @@ QuantumToolbox.to_dense(::Type{T}, A::AbstractCuSparseArray) where {T<:Number} =
104104

105105
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, args...) = sparse(args..., fmt = :csc)
106106
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, args...) = sparse(args..., fmt = :csr)
107-
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, I::AbstractVector, J::AbstractVector, V::AbstractVector, m::Int, n::Int) =
108-
CuSparseMatrixCSC(sparse(I, J, V, m, n))
107+
QuantumToolbox._sparse_similar(
108+
A::CuSparseMatrixCSC,
109+
I::AbstractVector,
110+
J::AbstractVector,
111+
V::AbstractVector,
112+
m::Int,
113+
n::Int,
114+
) = CuSparseMatrixCSC(sparse(I, J, V, m, n))
109115
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSC, m::Int, n::Int) = CuSparseMatrixCSC(spzeros(eltype(A), m, n))
110-
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, I::AbstractVector, J::AbstractVector, V::AbstractVector, m::Int, n::Int) =
111-
CuSparseMatrixCSR(sparse(I, J, V, m, n))
116+
QuantumToolbox._sparse_similar(
117+
A::CuSparseMatrixCSR,
118+
I::AbstractVector,
119+
J::AbstractVector,
120+
V::AbstractVector,
121+
m::Int,
122+
n::Int,
123+
) = CuSparseMatrixCSR(sparse(I, J, V, m, n))
112124
QuantumToolbox._sparse_similar(A::CuSparseMatrixCSR, m::Int, n::Int) = CuSparseMatrixCSR(spzeros(eltype(A), m, n))
113-
QuantumToolbox.allowed_setindex!(A::AbstractCuSparseArray, v, I...) = @allowscalar A[I...] = v
114125
end

src/steadystate.jl

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -345,86 +345,113 @@ function steadystate_fourier(
345345
return _steadystate_fourier(L_0, L_p, L_m, ωd, solver; n_max = n_max, tol = tol, kwargs...)
346346
end
347347

348+
import SparseArrays: sparse, spdiagm, I
349+
348350
function _steadystate_fourier(
349-
L_0::QuantumObject{SuperOperator},
350-
L_p::QuantumObject{SuperOperator},
351-
L_m::QuantumObject{SuperOperator},
351+
L0::QuantumObject{SuperOperator},
352+
Lp::QuantumObject{SuperOperator},
353+
Lm::QuantumObject{SuperOperator},
352354
ωd::Number,
353355
solver::SteadyStateLinearSolver;
354356
n_max::Integer = 1,
355357
tol::R = 1e-8,
356358
kwargs...,
357359
) where {R<:Real}
358-
T1 = eltype(L_0)
359-
T2 = eltype(L_p)
360-
T3 = eltype(L_m)
361-
T = promote_type(T1, T2, T3)
360+
A0 = get_data(L0)
361+
Ap = get_data(Lp)
362+
Am = get_data(Lm)
362363

363-
L_0_mat = get_data(L_0)
364-
L_p_mat = get_data(L_p)
365-
L_m_mat = get_data(L_m)
366-
367-
N = size(L_0_mat, 1)
364+
N = size(A0, 1)
368365
Ns = isqrt(N)
369-
n_fourier = 2 * n_max + 1
366+
n_fourier = 2*n_max + 1
370367
n_list = (-n_max):n_max
371368

369+
T = promote_type(eltype(A0), eltype(Ap), eltype(Am))
372370
weight = one(T)
373-
rows = ones(Int, Ns)
374-
cols = [Ns * (j - 1) + j for j in 1:Ns]
375-
vals = fill(weight, Ns)
376-
Mn = _sparse_similar(L_0_mat, rows, cols, vals, N, N)
377-
L = L_0_mat + Mn
378-
379-
M = _sparse_similar(L_0_mat, n_fourier * N, n_fourier * N)
380-
for i in 1:(n_fourier-1)
381-
M_block = _sparse_similar(L_0_mat, N, N)
382-
copyto!(M_block, L_p_mat)
383-
M[(i*N+1):((i+1)*N), ((i-1)*N+1):(i*N)] = M_block
384-
M_block = _sparse_similar(L_0_mat, N, N)
385-
copyto!(M_block, L_m_mat)
386-
M[((i-1)*N+1):(i*N), (i*N+1):((i+1)*N)] = M_block
387-
end
388371

389-
for i in 1:n_fourier
390-
n = n_list[i]
391-
M_block = _sparse_similar(L_0_mat, N, N)
392-
copyto!(M_block, L)
393-
M_block -= 1im * ωd * n * sparse(I, N, N)
394-
M[((i-1)*N+1):(i*N), ((i-1)*N+1):(i*N)] = M_block
372+
# Stabilization block Mn = weight * I (size N)
373+
rows = _dense_similar(A0, Ns)
374+
cols = _dense_similar(A0, Ns)
375+
vals = _dense_similar(A0, Ns)
376+
fill!(rows, 1)
377+
for j in 1:Ns
378+
cols[j] = Ns*(j-1) + j
379+
vals[j] = weight
380+
end
381+
Mn = _sparse_similar(A0, rows, cols, vals, N, N)
382+
383+
# Base Liouvillian L = A0 + Mn
384+
L = A0 + Mn
385+
386+
# Build off-diagonal blocks via Kron
387+
# Kp shifts down, Km shifts up in Fourier index
388+
Kp = _sparse_similar(A0, spdiagm(-1 => ones(T, n_fourier-1)))
389+
Km = _sparse_similar(A0, spdiagm(1 => ones(T, n_fourier-1)))
390+
M = kron(Kp, Ap) + kron(Km, Am)
391+
392+
# Precompute identity block on N×N
393+
# identity in same sparse format as A0
394+
Id_vec = _dense_similar(A0, N)
395+
fill!(Id_vec, one(T))
396+
Id = _sparse_similar(A0, spdiagm(0 => Id_vec))
397+
398+
# Build diagonal blocks via Kron
399+
for (i, n) in enumerate(n_list)
400+
# projector onto Fourier index i
401+
Pi = sparse([i], [i], one(T), n_fourier, n_fourier)
402+
Pi_blk = _sparse_similar(A0, Pi)
403+
# block = L - i*ωd*n*I_N
404+
D = L - (1im * ωd * n) * Id
405+
D_blk = _sparse_similar(A0, D)
406+
M += kron(Pi_blk, D_blk)
395407
end
396408

397-
v0 = similar(L_0_mat, n_fourier * N)
409+
# Initialize RHS vector v0 v0
410+
v0 = _dense_similar(A0, n_fourier * N)
398411
fill!(v0, zero(T))
399-
target_idx = n_max*N + 1
400-
QuantumToolbox.allowed_setindex!(v0, weight, target_idx)
412+
tidx = n_max*N + 1
413+
allowed_setindex!(v0, weight, tidx)
401414

415+
# Preconditioners
402416
(haskey(kwargs, :Pl) || haskey(kwargs, :Pr)) && error("The use of preconditioners must be defined in the solver.")
403417
if !isnothing(solver.Pl)
404-
kwargs = merge((; kwargs...), (Pl = solver.Pl(M),))
418+
kwargs = (; kwargs..., Pl = solver.Pl(M))
405419
elseif isa(M, SparseMatrixCSC)
406-
kwargs = merge((; kwargs...), (Pl = ilu(M, τ = 0.01),))
420+
kwargs = (; kwargs..., Pl = ilu(M, τ = 0.01))
421+
end
422+
if !isnothing(solver.Pr)
423+
kwargs = (; kwargs..., Pr = solver.Pr(M))
424+
end
425+
if !haskey(kwargs, :abstol)
426+
kwargs = (; kwargs..., abstol = tol)
427+
end
428+
if !haskey(kwargs, :reltol)
429+
kwargs = (; kwargs..., reltol = tol)
407430
end
408-
!isnothing(solver.Pr) && (kwargs = merge((; kwargs...), (Pr = solver.Pr(M),)))
409-
!haskey(kwargs, :abstol) && (kwargs = merge((; kwargs...), (abstol = tol,)))
410-
!haskey(kwargs, :reltol) && (kwargs = merge((; kwargs...), (reltol = tol,)))
411431

432+
# Solve linear system
433+
prob = LinearProblem(M, v0)
434+
ρtot = solve(prob, solver.alg; kwargs...).u
435+
prob = LinearProblem(M, v0)
436+
ρtot = solve(prob, solver.alg; kwargs...).u
412437
prob = LinearProblem(M, v0)
413438
ρtot = solve(prob, solver.alg; kwargs...).u
414439

440+
# Extract ρ0 and normalize
415441
offset1 = n_max * N
416442
offset2 = (n_max + 1) * N
417-
ρ0 = reshape(ρtot[(offset1+1):offset2], Ns, Ns)
418-
ρ0_tr = tr(ρ0)
419-
ρ0 = ρ0 / ρ0_tr
420-
ρ0 = QuantumObject((ρ0 + ρ0') / 2, type = Operator(), dims = L_0.dimensions)
421-
ρtot = ρtot / ρ0_tr
443+
blk0 = reshape(ρtot[(offset1+1):offset2], Ns, Ns)
444+
tr0 = tr(blk0)
445+
blk0 ./= tr0
446+
ρ0 = QuantumObject((blk0 + blk0')/2, type = Operator(), dims = L0.dimensions)
422447

448+
# Collect higher-order Fourier components
423449
ρ_list = [ρ0]
424450
for i in 0:(n_max-1)
425-
ρi_m = reshape(ρtot[(offset1-(i+1)*N+1):(offset1-i*N)], Ns, Ns)
426-
ρi_m = QuantumObject(ρi_m, type = Operator(), dims = L_0.dimensions)
427-
push!(ρ_list, ρi_m)
451+
idx1 = offset1 - (i+1)*N + 1
452+
idx2 = offset1 - i*N
453+
blk = reshape(ρtot[idx1:idx2], Ns, Ns)
454+
push!(ρ_list, QuantumObject(blk, type = Operator(), dims = L0.dimensions))
428455
end
429456

430457
return ρ_list

src/utilities.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,5 +200,3 @@ _convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:AbstractFloat} = Float6
200200
_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:AbstractFloat} = Float32
201201
_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{64}) where {T<:Union{Int,AbstractFloat}} = ComplexF64
202202
_convert_eltype_wordsize(::Type{Complex{T}}, ::Val{32}) where {T<:Union{Int,AbstractFloat}} = ComplexF32
203-
204-
function allowed_setindex! end

0 commit comments

Comments
 (0)