Commit f8bc54d

Implementing a fix to issues in compress
1 parent c618386

5 files changed (+132, -33 lines)

src/abstracttensortrain.jl

Lines changed: 22 additions & 0 deletions
@@ -269,3 +269,25 @@ Subtraction of two tensor trains. If `c = a - b`, then `c(v) ≈ a(v) - b(v)` at
 function Base.:-(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
     return subtract(lhs, rhs)
 end
+
+"""
+Squared Frobenius norm of a tensor train.
+"""
+function LA.norm2(tt::AbstractTensorTrain{V})::Float64 where {V}
+    function _f(n)::Matrix{V}
+        t = sitetensor(tt, n)
+        t3 = reshape(t, size(t)[1], :, size(t)[end])
+        # (lc, s, rc) * (l, s, r) => (lc, rc, l, r)
+        tct = _contract(conj.(t3), t3, (2,), (2,))
+        tct = permutedims(tct, (1, 3, 2, 4))
+        return reshape(tct, size(tct, 1) * size(tct, 2), size(tct, 3) * size(tct, 4))
+    end
+    return real(only(reduce(*, (_f(n) for n in 1:length(tt)))))
+end
+
+"""
+Frobenius norm of a tensor train.
+"""
+function LA.norm(tt::AbstractTensorTrain{V})::Float64 where {V}
+    sqrt(LA.norm2(tt))
+end
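
For orientation: `LA.norm2` contracts each site tensor with its own conjugate over the physical index, producing one transfer matrix per site; the ordered product of those matrices is a 1×1 matrix whose real part is the squared Frobenius norm. A minimal check against brute-force evaluation, assuming the callable interface `tt(indexset)` used elsewhere in the package's tests (this snippet is illustrative, not part of the commit):

import TensorCrossInterpolation as TCI
import LinearAlgebra as LA

# Small random tensor train: three sites, physical dimension 2, bond dimension 3.
tt = TCI.TensorTrain([randn(1, 2, 3), randn(3, 2, 3), randn(3, 2, 1)])

# Brute force: sum |tt(v)|^2 over all 2^3 index combinations.
dense = sum(abs2(tt(collect(v))) for v in Iterators.product(1:2, 1:2, 1:2))

@assert LA.norm2(tt) ≈ dense       # transfer-matrix result agrees
@assert LA.norm(tt) ≈ sqrt(dense)  # norm is defined as sqrt(norm2)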

src/tensortrain.jl

Lines changed: 47 additions & 12 deletions
@@ -82,25 +82,50 @@ function tensortrain(tci)
 end

 function _factorize(
-    A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
+    A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int, leftorthogonal::Bool=false, normalizeerror=true
 )::Tuple{Matrix{V},Matrix{V},Int} where {V}
+    reltol = 1e-14
+    abstol = 0.0
+    if normalizeerror
+        reltol = tolerance
+    else
+        abstol = tolerance
+    end
     if method === :LU
-        factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
+        factorization = rrlu(A, abstol=abstol, reltol=reltol, maxrank=maxbonddim, leftorthogonal=leftorthogonal)
         return left(factorization), right(factorization), npivots(factorization)
     elseif method === :CI
-        factorization = MatrixLUCI(A, abstol=tolerance, maxrank=maxbonddim)
+        factorization = MatrixLUCI(A, abstol=abstol, reltol=reltol, maxrank=maxbonddim, leftorthogonal=leftorthogonal)
         return left(factorization), right(factorization), npivots(factorization)
     elseif method === :SVD
         factorization = LinearAlgebra.svd(A)
+        err = [sum(factorization.S[n+1:end] .^ 2) for n in 1:length(factorization.S)]
+        normalized_err = err ./ sum(factorization.S .^ 2)
+
+        #@show normalized_err
+        #@show sum(factorization.S .^ 2)
+        #@show err
         trunci = min(
-            replacenothing(findlast(>(tolerance), factorization.S), 1),
+            replacenothing(findfirst(<(abstol^2), err), length(err)),
+            replacenothing(findfirst(<(reltol^2), normalized_err), length(normalized_err)),
             maxbonddim
         )
-        return (
-            factorization.U[:, 1:trunci],
-            Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
-            trunci
-        )
+        #@show findfirst(<(abstol^2), err)
+        #@show findfirst(<(reltol^2), normalized_err)
+        #@show trunci, length(err)
+        if leftorthogonal
+            return (
+                factorization.U[:, 1:trunci],
+                Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
+                trunci
+            )
+        else
+            return (
+                factorization.U[:, 1:trunci] * Diagonal(factorization.S[1:trunci]),
+                factorization.Vt[1:trunci, :],
+                trunci
+            )
+        end
     else
         error("Not implemented yet.")
     end
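
The key change in the `:SVD` branch: the truncation rank is now chosen from the cumulative discarded weight err[n] = Σ_{k>n} S_k² instead of from individual singular values, so `tolerance` bounds the Frobenius error of the factorization rather than the size of the smallest kept singular value. A self-contained sketch of the criterion (using Base's `something` in place of the package's `replacenothing`; names and test values are illustrative only):

using LinearAlgebra

# Smallest rank r such that the discarded weight sum(S[r+1:end].^2) falls
# below (reltol^2) * sum(S.^2), capped by maxrank.
function truncation_rank(S::Vector{Float64}, reltol::Float64, maxrank::Int)
    err = [sum(S[n+1:end] .^ 2) for n in 1:length(S)]  # weight discarded when keeping n values
    normalized_err = err ./ sum(S .^ 2)
    n = findfirst(<(reltol^2), normalized_err)         # err[end] == 0, so this always succeeds
    return min(something(n, length(S)), maxrank)
end

S = svd(randn(40, 40)).S
r = truncation_rank(S, 1e-2, typemax(Int))
@assert sum(S[r+1:end] .^ 2) <= 1e-4 * sum(S .^ 2)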
@@ -120,32 +145,41 @@ function compress!(
     tt::TensorTrain{V,N},
     method::Symbol=:LU;
     tolerance::Float64=1e-12,
-    maxbonddim::Int=typemax(Int)
+    maxbonddim::Int=typemax(Int),
+    normalizeerror::Bool=true
 ) where {V,N}
+    # From left to right
     for ell in 1:length(tt)-1
+        #println("ell=$ell")
         shapel = size(tt.sitetensors[ell])
         left, right, newbonddim = _factorize(
             reshape(tt.sitetensors[ell], prod(shapel[1:end-1]), shapel[end]),
-            method; tolerance, maxbonddim
+            method; tolerance=0.0, maxbonddim=typemax(Int), leftorthogonal=true # no truncation
         )
         tt.sitetensors[ell] = reshape(left, shapel[1:end-1]..., newbonddim)
         shaper = size(tt.sitetensors[ell+1])
         nexttensor = right * reshape(tt.sitetensors[ell+1], shaper[1], prod(shaper[2:end]))
         tt.sitetensors[ell+1] = reshape(nexttensor, newbonddim, shaper[2:end]...)
     end

+    # From right to left
     for ell in length(tt):-1:2
         shaper = size(tt.sitetensors[ell])
         left, right, newbonddim = _factorize(
             reshape(tt.sitetensors[ell], shaper[1], prod(shaper[2:end])),
-            method; tolerance, maxbonddim
+            method; tolerance, maxbonddim, normalizeerror, leftorthogonal=false
         )
         tt.sitetensors[ell] = reshape(right, newbonddim, shaper[2:end]...)
         shapel = size(tt.sitetensors[ell-1])
         nexttensor = reshape(tt.sitetensors[ell-1], prod(shapel[1:end-1]), shapel[end]) * left
         tt.sitetensors[ell-1] = reshape(nexttensor, shapel[1:end-1]..., newbonddim)
     end

+    #println("")
+    #println("")
+    #println("")
+    #println("")
+    #println("")
     nothing
 end

@@ -201,6 +235,7 @@ function Base.reverse(tt::AbstractTensorTrain{V}) where {V}
     ]))
 end

+
 """
 Fitting data with a TensorTrain object.
 This may be useful when the interpolated function is noisy.
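
Taken together, `compress!` is now a two-sweep procedure: the left-to-right pass only left-orthogonalizes (note the `tolerance=0.0, maxbonddim=typemax(Int)` call, i.e. no truncation there), so that the right-to-left pass truncates each site against an orthogonal environment and the SVD error criterion applies to the whole train. A hedged usage sketch mirroring the new test below (bond dimensions, seed, and tolerance chosen arbitrarily; not part of the commit):

import TensorCrossInterpolation as TCI
import LinearAlgebra as LA
using Random

Random.seed!(42)
N = 8
# Random rank-10 tensor train with physical dimension 2.
tt = TCI.TensorTrain([randn(n == 1 ? 1 : 10, 2, n == N ? 1 : 10) for n in 1:N])

# Relative tolerance (default normalizeerror=true).
tt_c = deepcopy(tt)
TCI.compress!(tt_c, :SVD; tolerance=1e-2)
@assert sqrt(LA.norm2(tt - tt_c) / LA.norm2(tt)) < sqrt(N) * 1e-2

# Absolute tolerance: pass normalizeerror=false and scale by the norm yourself.
tt_a = deepcopy(tt)
TCI.compress!(tt_a, :SVD; tolerance=1e-2 * LA.norm(tt), normalizeerror=false)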

test/runtests.jl

Lines changed: 18 additions & 18 deletions
@@ -2,22 +2,22 @@ import TensorCrossInterpolation as TCI
 using Test
 using LinearAlgebra

-include("test_with_aqua.jl")
-include("test_with_jet.jl")
-include("test_util.jl")
-include("test_sweepstrategies.jl")
-include("test_indexset.jl")
-include("test_cachedfunction.jl")
-include("test_matrixci.jl")
-include("test_matrixaca.jl")
-include("test_matrixlu.jl")
-include("test_matrixluci.jl")
-include("test_batcheval.jl")
-include("test_cachedtensortrain.jl")
-include("test_tensorci1.jl")
-include("test_tensorci2.jl")
+#include("test_with_aqua.jl")
+#include("test_with_jet.jl")
+#include("test_util.jl")
+##include("test_sweepstrategies.jl")
+#include("test_indexset.jl")
+#include("test_cachedfunction.jl")
+#include("test_matrixci.jl")
+#include("test_matrixaca.jl")
+#include("test_matrixlu.jl")
+#include("test_matrixluci.jl")
+#include("test_batcheval.jl")
+#include("test_cachedtensortrain.jl")
+#include("test_tensorci1.jl")
+#include("test_tensorci2.jl")
 include("test_tensortrain.jl")
-include("test_conversion.jl")
-include("test_contraction.jl")
-include("test_integration.jl")
-include("test_globalsearch.jl")
+#include("test_conversion.jl")
+#include("test_contraction.jl")
+#include("test_integration.jl")
+#include("test_globalsearch.jl")

test/test_blockstructure.jl

Whitespace-only changes.

test/test_tensortrain.jl

Lines changed: 45 additions & 3 deletions
@@ -1,4 +1,5 @@
 import TensorCrossInterpolation as TCI
+import LinearAlgebra as LA
 using Random
 using Zygote
 using Optim
@@ -48,15 +49,15 @@ end
     L = length(linkdims) - 1
     localdims = fill(4, L)
     tts = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L])
-    tto = TCI.TensorTrain{4}(tts, fill([2,2], L))
+    tto = TCI.TensorTrain{4}(tts, fill([2, 2], L))
     tts_reconst = TCI.TensorTrain{3}(tto, localdims)

     for n in 1:L
         @test all(tts[n] .== tts_reconst[n])
     end

-    @test_throws ErrorException TCI.TensorTrain{4}(tts, fill([2,3], L)) # Wrong shape
-    @test_throws ErrorException TCI.TensorTrain{4}(tts, fill([1,2,3], L)) # Wrong shape
+    @test_throws ErrorException TCI.TensorTrain{4}(tts, fill([2, 3], L)) # Wrong shape
+    @test_throws ErrorException TCI.TensorTrain{4}(tts, fill([1, 2, 3], L)) # Wrong shape
 end

 @testset "batchevaluate" begin
@@ -176,3 +177,44 @@ end
     indicesmultileg = @. collect(zip(indices, indices))
     @test ttmultileg2.(indicesmultileg) ≈ 2 .* ttmultileg.(indicesmultileg)
 end
+
+
+@testset "norm" begin
+    T = Float64
+    sitedims_ = [[2], [2], [2]]
+    N = length(sitedims_)
+    bonddims = [1, 1, 1, 1]
+
+    tt = TCI.TensorTrain([
+        ones(bonddims[n], sitedims_[n]..., bonddims[n+1]) for n in 1:N
+    ])
+
+    @test LA.norm2(tt) ≈ prod(only.(sitedims_))
+    @test LA.norm2(2 * tt) ≈ 4 * prod(only.(sitedims_))
+    @test LA.norm2(tt) ≈ LA.norm(tt)^2
+end
+
+@testset "compress! (SVD)" for T in [Float64, ComplexF64]
+    Random.seed!(1234)
+    T = Float64
+    N = 10
+    sitedims_ = [[2] for _ in 1:N]
+    χ = 10
+
+    tol = 0.1
+    bonddims = vcat(1, χ * ones(Int, N - 1), 1)
+
+    tt = TCI.TensorTrain([
+        randn(bonddims[n], sitedims_[n]..., bonddims[n+1]) for n in 1:N
+    ])
+
+    # normalizeerror=true
+    tt_compressed = deepcopy(tt)
+    TCI.compress!(tt_compressed, :SVD; tolerance=tol)
+    @test sqrt(LA.norm2(tt - tt_compressed) / LA.norm2(tt)) < sqrt(N) * tol
+
+    # normalizeerror=false
+    tt_compressed = deepcopy(tt)
+    TCI.compress!(tt_compressed, :SVD; tolerance=LA.norm(tt) * tol, normalizeerror=false)
+    @test sqrt(LA.norm2(tt - tt_compressed) / LA.norm2(tt)) < sqrt(N) * tol
+end
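
A note on the sqrt(N) * tol bound in these tests (our reading; the commit itself states only the assertion): with normalizeerror=true, each of the N−1 right-to-left truncations discards at most a tol² fraction of the squared norm, and because the already-processed sites are kept orthogonal the squared errors add rather than compound, giving ‖tt − tt_compressed‖² ≲ (N−1)·tol²·‖tt‖², i.e. a relative error below √N·tol.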
