@@ -56,6 +56,64 @@ function tensortrain(tci)
56
56
return TensorTrain (tci)
57
57
end
58
58
59
+ function _factorize (
60
+ A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int
61
+ ) where {V}
62
+ if method === :LU
63
+ factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
64
+ return left (factorization), right (factorization), npivots (factorization)
65
+ elseif method === :CI
66
+ factorization = MatrixLUCI (A, abstol= tolerance, maxrank= maxbonddim)
67
+ return left (factorization), right (factorization), npivots (factorization)
68
+ elseif method === :SVD
69
+ factorization = LinearAlgebra. svd (A)
70
+ trunci = min (
71
+ replacenothing (findlast (> (tolerance), factorization. S), 1 ),
72
+ maxbonddim
73
+ )
74
+ return (
75
+ factorization. U[:, 1 : trunci],
76
+ Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
77
+ trunci
78
+ )
79
+ else
80
+ error (" Not implemented yet." )
81
+ end
82
+ end
83
+
84
+ function compress! (
85
+ tt:: TensorTrain{V, N} ,
86
+ method:: Symbol = :LU ;
87
+ tolerance:: Float64 = 1e-12 ,
88
+ maxbonddim= typemax (Int)
89
+ ) where {V, N}
90
+ for ell in 1 : length (tt)- 1
91
+ shapel = size (tt. sitetensors[ell])
92
+ left, right, newbonddim = _factorize (
93
+ reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]),
94
+ method; tolerance, maxbonddim
95
+ )
96
+ tt. sitetensors[ell] = reshape (left, shapel[1 : end - 1 ]. .. , newbonddim)
97
+ shaper = size (tt. sitetensors[ell+ 1 ])
98
+ nexttensor = right * reshape (tt. sitetensors[ell+ 1 ], shaper[1 ], prod (shaper[2 : end ]))
99
+ tt. sitetensors[ell+ 1 ] = reshape (nexttensor, newbonddim, shaper[2 : end ]. .. )
100
+ end
101
+
102
+ for ell in length (tt): - 1 : 2
103
+ shaper = size (tt. sitetensors[ell])
104
+ left, right, newbonddim = _factorize (
105
+ reshape (tt. sitetensors[ell], shaper[1 ], prod (shaper[2 : end ])),
106
+ method; tolerance, maxbonddim
107
+ )
108
+ tt. sitetensors[ell] = reshape (right, newbonddim, shaper[2 : end ]. .. )
109
+ shapel = size (tt. sitetensors[ell- 1 ])
110
+ nexttensor = reshape (tt. sitetensors[ell- 1 ], prod (shapel[1 : end - 1 ]), shapel[end ]) * left
111
+ tt. sitetensors[ell- 1 ] = reshape (nexttensor, shapel[1 : end - 1 ]. .. , newbonddim)
112
+ end
113
+
114
+ nothing
115
+ end
116
+
59
117
60
118
"""
61
119
Fitting data with a TensorTrain object.
@@ -83,7 +141,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
83
141
reshape (
84
142
x[obj. offsets[n]+ 1 : obj. offsets[n+ 1 ]],
85
143
size (obj. tt[n])
86
- )
144
+ )
87
145
for n in 1 : length (obj. tt)
88
146
]
89
147
end
94
152
95
153
function (obj:: TensorTrainFit{ValueType} )(x:: Vector{ValueType} ) where {ValueType}
96
154
tensors = to_tensors (obj, x)
97
- return sum ((abs2 (_evaluate (tensors, indexset) - obj. values[i]) for (i, indexset) in enumerate (obj. indexsets)))
155
+ return sum ((abs2 (_evaluate (tensors, indexset) - obj. values[i]) for (i, indexset) in enumerate (obj. indexsets)))
98
156
end
0 commit comments