Skip to content

Commit aaa26ee

Browse files
committed
🎮 GPU support for Lanczos
1 parent fce5537 commit aaa26ee

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

src/spectrum.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -171,54 +171,53 @@ function _spectrum(
171171
fT = _FType(L)
172172
cT = _CType(L)
173173

174-
# Handle input frequency range
175-
ωList = convert(Vector{fT}, ωlist) # Convert it to support GPUs and avoid type instabilities
176-
Length = length(ωList)
177-
#spec = Vector{fT}(undef, Length)
178-
179174
# Calculate |v₁> = B|ρss>
180175
ρss = mat2vec(steadystate(L))
181-
vₖ = Array{cT}((spre(B) * ρss).data)
176+
vₖ = (spre(B) * ρss).data
177+
178+
# Define (possibly GPU) vector type
179+
vT = typeof(vₖ)
182180

183181
# Calculate <w₁| = <I|A
184182
D = prod(L.dimensions)
185183
Ivec = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(cT, D)) # same as vec(system_identity_matrix)
186-
wₖ = transpose(typeof(vₖ)(Ivec)) * spre(A).data
184+
wₖ = transpose(vT(Ivec)) * spre(A).data
187185

188186
# Store the norm of the Green's function before renormalizing |v₁> and <w₁|
189187
gfNorm = abs(wₖ * vₖ)
190188
vₖ ./= sqrt(gfNorm)
191189
wₖ ./= sqrt(gfNorm)
192190

193-
# println(" type: $(typeof(vₖ))")
194-
# println(" type: $(typeof(wₖ))")
191+
# Handle input frequency range
192+
ωList = vT(convert(Vector{fT}, ωlist)) # Make sure they're real frequencies and potentially on GPU
193+
Length = length(ωList)
195194

196195
# Current and previous estimates of the spectrum
197-
lanczosFactor = zeros(cT, Length)
198-
lanczosFactor₋₁ = zeros(cT, Length)
196+
lanczosFactor = vT(zeros(cT, Length))
197+
lanczosFactor₋₁ = vT(zeros(cT, Length))
199198

200199
# Tridiagonal matrix elements
201200
αₖ = cT( 0)
202201
βₖ = cT(-1)
203202
δₖ = cT(+1)
204203

205204
# Current (k), previous (k-1), and second-to-last (k-2) terms of the A and B Euler sequences
206-
A₋₂ = ones(cT, Length)
207-
A₋₁ = zeros(cT, Length)
208-
Aₖ = zeros(cT, Length)
209-
B₋₂ = zeros(cT, Length)
210-
B₋₁ = ones(cT, Length)
211-
Bₖ = zeros(cT, Length)
205+
A₋₂ = vT( ones(cT, Length))
206+
A₋₁ = vT(zeros(cT, Length))
207+
Aₖ = vT(zeros(cT, Length))
208+
B₋₂ = vT(zeros(cT, Length))
209+
B₋₁ = vT( ones(cT, Length))
210+
Bₖ = vT(zeros(cT, Length))
212211

213212
# Maximum norm and residue
214-
maxNorm = zeros(cT, length(ωList))
213+
maxNorm = vT(zeros(cT, length(ωList)))
215214
maxResidue = fT(0.0)
216215

217216
# Previous and next left/right Krylov vectors
218-
v₋₁ = zeros(cT, (D^2, 1))
219-
v₊₁ = zeros(cT, (D^2, 1))
220-
w₋₁ = zeros(cT, (1, D^2))
221-
w₊₁ = zeros(cT, (1, D^2))
217+
v₋₁ = vT(zeros(cT, D^2))
218+
v₊₁ = vT(zeros(cT, D^2))
219+
w₋₁ = vT(zeros(cT, D^2))'
220+
w₊₁ = vT(zeros(cT, D^2))'
222221

223222
# Frequency of renormalization
224223
renormFrequency::typeof(solver.maxiter) = 1

0 commit comments

Comments
 (0)