@@ -61,21 +61,33 @@ function recompress!(
61
61
tolerance:: Float64 = 1e-12 , maxbonddim= typemax (Int),
62
62
method:: Symbol = :LU
63
63
) where {V}
64
- if method != = :LU
65
- error (" Not implemented yet." )
64
+ function factorize (A:: Matrix{V} )
65
+ if method === :LU
66
+ factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
67
+ return left (factorization), right (factorization), npivots (factorization)
68
+ elseif method === :CI
69
+ factorization = MatrixLUCI (A, abstol= tolerance, maxrank= maxbonddim)
70
+ return left (factorization), right (factorization), npivots (factorization)
71
+ elseif method === :SVD
72
+ factorization = LinearAlgebra. svd (A)
73
+ trunci = min (findlast (> (tolerance), factorization. S), maxbonddim)
74
+ return (
75
+ factorization. U[:, 1 : trunci],
76
+ Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
77
+ trunci
78
+ )
79
+ else
80
+ error (" Not implemented yet." )
81
+ end
66
82
end
67
83
68
84
for ell in 1 : length (tt)- 1
69
85
shapel = size (tt. sitetensors[ell])
70
- lu = rrlu (
71
- reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]);
72
- abstol= tolerance,
73
- maxrank= maxbonddim
74
- )
75
- tt. sitetensors[ell] = reshape (left (lu), shapel[1 : end - 1 ]. .. , npivots (lu))
86
+ left, right, newbonddim = factorize (reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]))
87
+ tt. sitetensors[ell] = reshape (left, shapel[1 : end - 1 ]. .. , newbonddim)
76
88
shaper = size (tt. sitetensors[ell+ 1 ])
77
- nexttensor = right (lu) * reshape (tt. sitetensors[ell+ 1 ], shaper[1 ], prod (shaper[2 : end ]))
78
- tt. sitetensors[ell+ 1 ] = reshape (nexttensor, npivots (lu) , shaper[2 : end ]. .. )
89
+ nexttensor = right * reshape (tt. sitetensors[ell+ 1 ], shaper[1 ], prod (shaper[2 : end ]))
90
+ tt. sitetensors[ell+ 1 ] = reshape (nexttensor, newbonddim , shaper[2 : end ]. .. )
79
91
end
80
92
nothing
81
93
end
@@ -107,7 +119,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
107
119
reshape (
108
120
x[obj. offsets[n]+ 1 : obj. offsets[n+ 1 ]],
109
121
size (obj. tt[n])
110
- )
122
+ )
111
123
for n in 1 : length (obj. tt)
112
124
]
113
125
end
@@ -116,5 +128,5 @@ _evaluate(tt, indexset) = only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
116
128
117
129
function (obj:: TensorTrainFit{ValueType} )(x:: Vector{ValueType} ) where {ValueType}
118
130
tensors = to_tensors (obj, x)
119
- return sum ((abs2 (_evaluate (tensors, indexset) - obj. values[i]) for (i, indexset) in enumerate (obj. indexsets)))
131
+ return sum ((abs2 (_evaluate (tensors, indexset) - obj. values[i]) for (i, indexset) in enumerate (obj. indexsets)))
120
132
end
0 commit comments