Skip to content

Commit a9d9076

Browse files
authored
Merge pull request #49 from tensor4all/terasaki/optimize-arrlu
Improve `src/matrixlu.jl` to optimize `arrlu`
2 parents b764487 + 0650adb commit a9d9076

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

src/matrixlu.jl

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,19 @@ function rrLU{T}(A::AbstractMatrix{T}; leftorthogonal::Bool=true) where {T}
9696
end
9797

9898
function swaprow!(lu::rrLU{T}, A::AbstractMatrix{T}, a, b) where {T}
99-
lu.rowpermutation[[a, b]] = lu.rowpermutation[[b, a]]
100-
A[[a, b], :] = A[[b, a], :]
99+
lurp = lu.rowpermutation
100+
lurp[a], lurp[b] = lurp[b], lurp[a]
101+
@inbounds for j in axes(A, 2)
102+
A[a, j], A[b, j] = A[b, j], A[a, j]
103+
end
101104
end
102105

103106
function swapcol!(lu::rrLU{T}, A::AbstractMatrix{T}, a, b) where {T}
104-
lu.colpermutation[[a, b]] = lu.colpermutation[[b, a]]
105-
A[:, [a, b]] = A[:, [b, a]]
107+
lucp = lu.colpermutation
108+
lucp[a], lucp[b] = lucp[b], lucp[a]
109+
@inbounds for i in axes(A, 1)
110+
A[i, a], A[i, b] = A[i, b], A[i, a]
111+
end
106112
end
107113

108114
function addpivot!(lu::rrLU{T}, A::AbstractMatrix{T}, newpivot) where {T}
@@ -111,9 +117,9 @@ function addpivot!(lu::rrLU{T}, A::AbstractMatrix{T}, newpivot) where {T}
111117
swapcol!(lu, A, k, newpivot[2])
112118

113119
if lu.leftorthogonal
114-
A[k+1:end, k] /= A[k, k]
120+
A[k+1:end, k] ./= A[k, k]
115121
else
116-
A[k, k+1:end] /= A[k, k]
122+
A[k, k+1:end] ./= A[k, k]
117123
end
118124

119125
# perform BLAS subroutine manually: A <- -x * transpose(y) + A
@@ -139,12 +145,12 @@ function _optimizerrlu!(
139145
reltol::Number=1e-14,
140146
abstol::Number=0.0
141147
) where {T}
142-
maxrank = min(maxrank, size(A)...)
148+
maxrank = min(maxrank, size(A, 1), size(A, 2))
143149
maxerror = 0.0
144150
while lu.npivot < maxrank
145151
k = lu.npivot + 1
146152
newpivot = submatrixargmax(abs2, A, k)
147-
lu.error = abs(A[newpivot...])
153+
lu.error = abs(A[newpivot[1], newpivot[2]])
148154
# Add at least 1 pivot to get a well-defined L * U
149155
if (abs(lu.error) < reltol * maxerror || abs(lu.error) < abstol) && lu.npivot > 0
150156
break
@@ -153,8 +159,8 @@ function _optimizerrlu!(
153159
addpivot!(lu, A, newpivot)
154160
end
155161

156-
lu.L = tril(A[:, 1:lu.npivot])
157-
lu.U = triu(A[1:lu.npivot, :])
162+
lu.L = tril(@view A[:, 1:lu.npivot])
163+
lu.U = triu(@view A[1:lu.npivot, :])
158164
if any(isnan.(lu.L))
159165
error("lu.L contains NaNs")
160166
end
@@ -271,16 +277,16 @@ function arrlu(
271277
I2 = setdiff(1:matrixsize[1], I0)
272278
lu.rowpermutation = vcat(I0, I2)
273279
L2 = _batchf(I2, J0)
274-
cols2Lmatrix!(L2, lu.U[1:lu.npivot, 1:lu.npivot], leftorthogonal)
275-
lu.L = vcat(lu.L[1:lu.npivot, 1:lu.npivot], L2)
280+
cols2Lmatrix!(L2, (@view lu.U[1:lu.npivot, 1:lu.npivot]), leftorthogonal)
281+
lu.L = vcat((@view lu.L[1:lu.npivot, 1:lu.npivot]), L2)
276282
end
277283

278284
if size(lu.U, 2) < matrixsize[2]
279285
J2 = setdiff(1:matrixsize[2], J0)
280286
lu.colpermutation = vcat(J0, J2)
281287
U2 = _batchf(I0, J2)
282-
rows2Umatrix!(U2, lu.L[1:lu.npivot, 1:lu.npivot], leftorthogonal)
283-
lu.U = hcat(lu.U[1:lu.npivot, 1:lu.npivot], U2)
288+
rows2Umatrix!(U2, (@view lu.L[1:lu.npivot, 1:lu.npivot]), leftorthogonal)
289+
lu.U = hcat((@view lu.U[1:lu.npivot, 1:lu.npivot]), U2)
284290
end
285291

286292
return lu
@@ -313,8 +319,17 @@ function cols2Lmatrix!(C::AbstractMatrix, P::AbstractMatrix, leftorthogonal::Boo
313319
end
314320

315321
for k in axes(P, 1)
316-
C[:, k] /= P[k, k]
317-
C[:, k+1:end] -= C[:, k] * transpose(P[k, k+1:end])
322+
C[:, k] ./= P[k, k]
323+
# C[:, k+1:end] .-= C[:, k] * transpose(P[k, k+1:end])
324+
x = @view C[:, k]
325+
y = @view P[k, k+1:end]
326+
= @view C[:, k+1:end]
327+
@inbounds for j in eachindex(axes(C̃, 2), y)
328+
for i in eachindex(axes(C̃, 1), x)
329+
# update `C[:, k+1:end]`
330+
C̃[i, j] -= x[i] * y[j]
331+
end
332+
end
318333
end
319334
return C
320335
end
@@ -327,8 +342,17 @@ function rows2Umatrix!(R::AbstractMatrix, P::AbstractMatrix, leftorthogonal::Boo
327342
end
328343

329344
for k in axes(P, 1)
330-
R[k, :] /= P[k, k]
331-
R[k+1:end, :] -= P[k+1:end, k] * transpose(R[k, :])
345+
R[k, :] ./= P[k, k]
346+
# R[k+1:end, :] -= P[k+1:end, k] * transpose(R[k, :])
347+
x = @view P[k+1:end, k]
348+
y = @view R[k, :]
349+
= @view R[k+1:end, :]
350+
@inbounds for j in eachindex(axes(R̃, 2), y)
351+
for i in eachindex(axes(R̃, 1), x)
352+
# update R[k+1:end, :]
353+
R̃[i, j] -= x[i] * y[j]
354+
end
355+
end
332356
end
333357
return R
334358
end

src/util.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ end
5353

5454
function pushrandomsubset!(subset, set, n::Int)
5555
topush = randomsubset(setdiff(set, subset), n)
56-
push!(subset, topush...)
56+
append!(subset, topush)
5757
nothing
5858
end
5959

0 commit comments

Comments
 (0)