Skip to content

Commit 4e3c0d0

Browse files
committed
Merge branch '44-use-full-pivoting-for-computing-site-tensors-for-tci2' into 'main'
A rare event of non-square pivot matrix Closes #44 See merge request tensors4fields/TensorCrossInterpolation.jl!91
2 parents a80ee50 + 06ed5d8 commit 4e3c0d0

File tree

4 files changed

+160
-39
lines changed

4 files changed

+160
-39
lines changed

src/matrixlu.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,20 @@ mutable struct rrLU{T}
7777
npivot::Int
7878
error::Float64
7979

80+
function rrLU(rowpermutation::Vector{Int}, colpermutation::Vector{Int}, L::Matrix{T}, U::Matrix{T}, leftorthogonal, npivot, error) where {T}
81+
npivot == size(L, 2) || error("L must have the same number of columns as the number of pivots.")
82+
npivot == size(U, 1) || error("U must have the same number of rows as the number of pivots.")
83+
length(rowpermutation) == size(L, 1) || error("rowpermutation must have length equal to the number of pivots.")
84+
length(colpermutation) == size(U, 2) || error("colpermutation must have length equal to the number of pivots.")
85+
new{T}(rowpermutation, colpermutation, L, U, leftorthogonal, npivot, error)
86+
end
87+
8088
function rrLU{T}(nrows::Int, ncols::Int; leftorthogonal::Bool=true) where {T}
8189
new{T}(1:nrows, 1:ncols, zeros(nrows, 0), zeros(0, ncols), leftorthogonal, 0, NaN)
8290
end
8391
end
8492

93+
8594
function rrLU{T}(A::AbstractMatrix{T}; leftorthogonal::Bool=true) where {T}
8695
rrLU{T}(size(A)...; leftorthogonal=leftorthogonal)
8796
end
@@ -389,3 +398,70 @@ A special care is taken for a full-rank matrix: the last pivot error is set to z
389398
function lastpivoterror(lu::rrLU{T})::Float64 where {T}
390399
return lu.error
391400
end
401+
402+
403+
404+
"""
405+
Solve (LU) x = b
406+
407+
L: lower triangular matrix
408+
U: upper triangular matrix
409+
b: right-hand side vector
410+
411+
Return x
412+
413+
Note: Not optimized for performance
414+
"""
415+
function solve(L::Matrix{T}, U::Matrix{T}, b::Matrix{T}) where{T}
416+
N1, N2, N3 = size(L, 1), size(L, 2), size(U, 2)
417+
M = size(b, 2)
418+
419+
# Solve Ly = b
420+
y = zeros(T, N2, M)
421+
for i = 1:N2
422+
y[i, :] .= b[i, :]
423+
for k in 1:M
424+
for j in 1:i-1
425+
y[i, k] -= L[i, j] * y[j, k]
426+
end
427+
end
428+
y[i, :] ./= L[i, i]
429+
end
430+
431+
# Solve Ux = y
432+
x = zeros(T, N3, M)
433+
for i = N3:-1:1
434+
x[i, :] .= y[i, :]
435+
for k in 1:M
436+
for j = i+1:N3
437+
x[i, k] -= U[i, j] * x[j, k]
438+
end
439+
end
440+
x[i, :] ./= U[i, i]
441+
end
442+
443+
return x
444+
end
445+
446+
447+
"""
448+
Override solving Ax = b using LU decomposition
449+
"""
450+
function Base.:\(A::rrLU{T}, b::AbstractMatrix{T}) where{T}
451+
size(A, 1) == size(A, 2) || error("Matrix must be square.")
452+
A.npivot == size(A, 1) || error("rank-deficient matrix is not supportred!")
453+
b_perm = b[A.rowpermutation, :]
454+
x_perm = solve(A.L, A.U, b_perm)
455+
x = similar(x_perm)
456+
for i in 1:size(x, 1)
457+
x[A.colpermutation[i], :] .= x_perm[i, :]
458+
end
459+
return x
460+
end
461+
462+
463+
function Base.transpose(A::rrLU{T}) where{T}
464+
return rrLU(
465+
A.colpermutation, A.rowpermutation,
466+
Matrix(transpose(A.U)), Matrix(transpose(A.L)), !A.leftorthogonal, A.npivot, A.error)
467+
end

src/tensorci2.jl

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function kronecker(
290290
return MultiIndex[[i, js...] for i in 1:localdim, js in Jset][:]
291291
end
292292

293-
function setT!(
293+
function setsitetensor!(
294294
tci::TensorCI2{ValueType}, b::Int, T::AbstractArray{ValueType,N}
295295
) where {ValueType,N}
296296
tci.sitetensors[b] = reshape(
@@ -328,7 +328,7 @@ end
328328
# Backward compatibility
329329
const rmbadpivots! = sweep0site!
330330

331-
function setT!(
331+
function setsitetensor!(
332332
tci::TensorCI2{ValueType}, f, b::Int; leftorthogonal=true
333333
) where {ValueType}
334334
leftorthogonal || error("leftorthogonal==false is not supported!")
@@ -342,14 +342,16 @@ function setT!(
342342

343343
if (leftorthogonal && b == length(tci)) ||
344344
(!leftorthogonal && b == 1)
345-
setT!(tci, b, Pi1)
345+
setsitetensor!(tci, b, Pi1)
346346
return tci.sitetensors[b]
347347
end
348348

349349
P = reshape(
350350
filltensor(ValueType, f, tci.localdims, tci.Iset[b+1], tci.Jset[b], Val(0)),
351351
length(tci.Iset[b+1]), length(tci.Jset[b]))
352+
length(tci.Iset[b+1]) == length(tci.Jset[b]) || error("Pivot matrix at bond $(b) is not square!")
352353

354+
#Tmat = transpose(transpose(rrlu(P)) \ transpose(Pi1))
353355
Tmat = transpose(transpose(P) \ transpose(Pi1))
354356
tci.sitetensors[b] = reshape(Tmat, length(tci.Iset[b]), tci.localdims[b], length(tci.Iset[b+1]))
355357
return tci.sitetensors[b]
@@ -396,7 +398,7 @@ function sweep1site!(
396398
tci.Iset[b+forwardsweep] = Is[rowindices(luci)]
397399
tci.Jset[b-!forwardsweep] = Js[colindices(luci)]
398400
if updatetensors
399-
setT!(tci, b, forwardsweep ? left(luci) : right(luci))
401+
setsitetensor!(tci, b, forwardsweep ? left(luci) : right(luci))
400402
end
401403
if any(isnan.(tci.sitetensors[b]))
402404
error("Error: NaN in tensor T[$b]")
@@ -417,7 +419,7 @@ function sweep1site!(
417419
end
418420
localtensor = reshape(filltensor(
419421
ValueType, f, tci.localdims, tci.Iset[lastupdateindex], tci.Jset[lastupdateindex], Val(1)), shape)
420-
setT!(tci, lastupdateindex, localtensor)
422+
setsitetensor!(tci, lastupdateindex, localtensor)
421423
end
422424
nothing
423425
end
@@ -560,8 +562,8 @@ function updatepivots!(
560562
tci.Iset[b+1] = Icombined[rowindices(luci)]
561563
tci.Jset[b] = Jcombined[colindices(luci)]
562564
if length(extraIset) == 0 && length(extraJset) == 0
563-
setT!(tci, b, left(luci))
564-
setT!(tci, b + 1, right(luci))
565+
setsitetensor!(tci, b, left(luci))
566+
setsitetensor!(tci, b + 1, right(luci))
565567
end
566568
updateerrors!(tci, b, pivoterrors(luci))
567569
nothing
@@ -744,24 +746,22 @@ function optimize!(
744746
end
745747
end
746748

747-
if rank(tci) > maxbonddim
748-
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
749-
abstol = pivottolerance * errornormalization;
750-
sweep2site!(
751-
tci, f, 1;
752-
abstol=abstol,
753-
maxbonddim=maxbonddim,
754-
pivotsearch=pivotsearch,
755-
strictlynested=strictlynested,
756-
verbosity=verbosity
757-
)
758-
end
749+
# Extra one sweep by the 1-site update to
750+
# (1) Remove unnecessary pivots added by global pivots
751+
# Note: a pivot matrix can be non-square after adding global pivots,
752+
# or the bond dimension exceeds maxbonddim
753+
# (2) Compute site tensors
754+
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
755+
abstol = pivottolerance * errornormalization;
756+
sweep1site!(
757+
tci,
758+
f,
759+
abstol=abstol,
760+
maxbonddim=maxbonddim,
761+
)
759762

760-
if !issitetensorsavailable(tci)
761-
fillsitetensors!(tci, f)
762-
end
763+
_sanitycheck(tci)
763764

764-
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
765765
return ranks, errors ./ errornormalization
766766
end
767767

@@ -993,13 +993,17 @@ end
993993

994994
function fillsitetensors!(
995995
tci::TensorCI2{ValueType}, f) where {ValueType}
996-
#==
997-
for b in 1:length(tci)-1
998-
rmbadpivots!(tci, f, b)
999-
end
1000-
==#
1001996
for b in 1:length(tci)
1002-
setT!(tci, f, b)
997+
setsitetensor!(tci, f, b)
1003998
end
1004999
nothing
10051000
end
1001+
1002+
1003+
function _sanitycheck(tci::TensorCI2{ValueType})::Bool where {ValueType}
1004+
for b in 1:length(tci)-1
1005+
length(tci.Iset[b+1]) == length(tci.Jset[b]) || error("Pivot matrix at bond $(b) is not square!")
1006+
end
1007+
1008+
return true
1009+
end

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import TensorCrossInterpolation as TCI
22
using Test
33
using LinearAlgebra
44

5-
#==
65
include("test_with_aqua.jl")
76
include("test_with_jet.jl")
87
include("test_util.jl")
@@ -20,5 +19,4 @@ include("test_tensortrain.jl")
2019
include("test_conversion.jl")
2120
include("test_contraction.jl")
2221
include("test_integration.jl")
23-
==#
2422
include("test_globalsearch.jl")

test/test_matrixlu.jl

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
import TensorCrossInterpolation as TCI
33
using LinearAlgebra
4+
using Random
45

56
@testset "LU decomposition" begin
67
@testset "Argmax finder" begin
@@ -151,9 +152,9 @@ using LinearAlgebra
151152
0.490269 0.810266 0.7946
152153
]
153154
q = [
154-
0.239552 0.306094 0.299063 0.0382492 0.185462 0.0334971 0.697561 0.389596 0.105665 0.0912763
155-
0.0570609 0.56623 0.97183 0.994184 0.371695 0.284437 0.993251 0.902347 0.572944 0.0531369
156-
0.45002 0.461168 0.6086 0.613702 0.543997 0.759954 0.0959818 0.638499 0.407382 0.482592
155+
0.239552 0.306094 0.299063 0.0382492 0.185462 0.0334971 0.697561 0.389596 0.105665 0.0912763
156+
0.0570609 0.56623 0.97183 0.994184 0.371695 0.284437 0.993251 0.902347 0.572944 0.0531369
157+
0.45002 0.461168 0.6086 0.613702 0.543997 0.759954 0.0959818 0.638499 0.407382 0.482592
157158
]
158159

159160
A = p * q
@@ -175,11 +176,11 @@ using LinearAlgebra
175176

176177
@testset "lastpivoterror for limited maxrank or tolerance" begin
177178
A = [
178-
0.433088 0.956638 0.0907974 0.0447859 0.0196053
179-
0.855517 0.782503 0.291197 0.540828 0.358579
180-
0.37455 0.536457 0.205479 0.75896 0.701206
181-
0.47272 0.0172539 0.518177 0.242864 0.461635
182-
0.0676373 0.450878 0.672335 0.77726 0.540691
179+
0.433088 0.956638 0.0907974 0.0447859 0.0196053
180+
0.855517 0.782503 0.291197 0.540828 0.358579
181+
0.37455 0.536457 0.205479 0.75896 0.701206
182+
0.47272 0.0172539 0.518177 0.242864 0.461635
183+
0.0676373 0.450878 0.672335 0.77726 0.540691
183184
]
184185

185186
lu = TCI.rrlu(A, maxrank=2)
@@ -208,4 +209,46 @@ using LinearAlgebra
208209
@test size(lu) == size(A)
209210
@test maximum(abs.(TCI.left(lu) * TCI.right(lu) .- A)) < 1e-3
210211
end
212+
213+
214+
@testset "transpose" begin
215+
Random.seed!(1234)
216+
N1, N2, N3 = 5, 10, 3
217+
A = rand(N1, N2)
218+
219+
tlu = transpose(TCI.rrlu(A))
220+
221+
@test TCI.left(tlu) * TCI.right(tlu) A'
222+
end
223+
224+
@testset "solve by rrLU" begin
225+
Random.seed!(1234)
226+
N1, N2, N3 = 5, 5, 5
227+
M = 2
228+
L = tril(rand(N1, N2))
229+
U = triu(rand(N2, N3))
230+
b = rand(N1, M)
231+
232+
A = L * U
233+
lua = TCI.rrlu(A)
234+
@test TCI.left(lua) * TCI.right(lua) A
235+
236+
@test A * (lua \ b) b
237+
end
238+
239+
#==
240+
@testset "solve by rrLU (large matrix)" begin
241+
M = 1000
242+
N = 1000
243+
244+
L = qr(rand(N, N)).Q # Well-behaved matrix
245+
U = qr(rand(N, N)).Q
246+
b = rand(N, M)
247+
248+
A = L * U
249+
lua = TCI.rrlu(A)
250+
breconst = A * (lua \ b)
251+
@test breconst ≈ b
252+
end
253+
==#
211254
end

0 commit comments

Comments
 (0)