Skip to content

Commit c38bb01

Browse files
committed
added CI and SVD based truncations
1 parent 1eb8ec6 commit c38bb01

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

src/tensortrain.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,33 @@ function recompress!(
6161
tolerance::Float64=1e-12, maxbonddim=typemax(Int),
6262
method::Symbol=:LU
6363
) where {V}
64-
if method !== :LU
65-
error("Not implemented yet.")
64+
function factorize(A::Matrix{V})
65+
if method === :LU
66+
factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
67+
return left(factorization), right(factorization), npivots(factorization)
68+
elseif method === :CI
69+
factorization = MatrixLUCI(A, abstol=tolerance, maxrank=maxbonddim)
70+
return left(factorization), right(factorization), npivots(factorization)
71+
elseif method === :SVD
72+
factorization = LinearAlgebra.svd(A)
73+
trunci = min(findlast(>(tolerance), factorization.S), maxbonddim)
74+
return (
75+
factorization.U[:, 1:trunci],
76+
Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
77+
trunci
78+
)
79+
else
80+
error("Not implemented yet.")
81+
end
6682
end
6783

6884
for ell in 1:length(tt)-1
6985
shapel = size(tt.sitetensors[ell])
70-
lu = rrlu(
71-
reshape(tt.sitetensors[ell], prod(shapel[1:end-1]), shapel[end]);
72-
abstol=tolerance,
73-
maxrank=maxbonddim
74-
)
75-
tt.sitetensors[ell] = reshape(left(lu), shapel[1:end-1]..., npivots(lu))
86+
left, right, newbonddim = factorize(reshape(tt.sitetensors[ell], prod(shapel[1:end-1]), shapel[end]))
87+
tt.sitetensors[ell] = reshape(left, shapel[1:end-1]..., newbonddim)
7688
shaper = size(tt.sitetensors[ell+1])
77-
nexttensor = right(lu) * reshape(tt.sitetensors[ell+1], shaper[1], prod(shaper[2:end]))
78-
tt.sitetensors[ell+1] = reshape(nexttensor, npivots(lu), shaper[2:end]...)
89+
nexttensor = right * reshape(tt.sitetensors[ell+1], shaper[1], prod(shaper[2:end]))
90+
tt.sitetensors[ell+1] = reshape(nexttensor, newbonddim, shaper[2:end]...)
7991
end
8092
nothing
8193
end
@@ -107,7 +119,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
107119
reshape(
108120
x[obj.offsets[n]+1:obj.offsets[n+1]],
109121
size(obj.tt[n])
110-
)
122+
)
111123
for n in 1:length(obj.tt)
112124
]
113125
end
@@ -116,5 +128,5 @@ _evaluate(tt, indexset) = only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
116128

117129
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
118130
tensors = to_tensors(obj, x)
119-
return sum((abs2(_evaluate(tensors, indexset) - obj.values[i]) for (i, indexset) in enumerate(obj.indexsets)))
131+
return sum((abs2(_evaluate(tensors, indexset) - obj.values[i]) for (i, indexset) in enumerate(obj.indexsets)))
120132
end

test/test_tensortrain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Optim
2323
end
2424
@test gsum TCI.sum(tt)
2525

26-
for method in [:LU] #, :SVD, :CI]
26+
for method in [:LU, :CI, :SVD]
2727
ttcompressed = deepcopy(tt)
2828
TCI.recompress!(ttcompressed; maxbonddim=5, tolerance=1e-2, method)
2929
@test TCI.rank(ttcompressed) <= 5

0 commit comments

Comments
 (0)