@@ -56,6 +56,74 @@ function tensortrain(tci)
56
56
return TensorTrain (tci)
57
57
end
58
58
59
+ function _factorize (
60
+ A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int
61
+ ) where {V}
62
+ if method === :LU
63
+ factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
64
+ return left (factorization), right (factorization), npivots (factorization)
65
+ elseif method === :CI
66
+ factorization = MatrixLUCI (A, abstol= tolerance, maxrank= maxbonddim)
67
+ return left (factorization), right (factorization), npivots (factorization)
68
+ elseif method === :SVD
69
+ factorization = LinearAlgebra. svd (A)
70
+ trunci = min (
71
+ replacenothing (findlast (> (tolerance), factorization. S), 1 ),
72
+ maxbonddim
73
+ )
74
+ return (
75
+ factorization. U[:, 1 : trunci],
76
+ Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
77
+ trunci
78
+ )
79
+ else
80
+ error (" Not implemented yet." )
81
+ end
82
+ end
83
+
84
+ """
85
+ function compress!(
86
+ tt::TensorTrain{V, N},
87
+ method::Symbol=:LU;
88
+ tolerance::Float64=1e-12,
89
+ maxbonddim=typemax(Int)
90
+ ) where {V, N}
91
+
92
+ Compress the tensor train `tt` using `LU`, `CI` or `SVD` decompositions.
93
+ """
94
+ function compress! (
95
+ tt:: TensorTrain{V, N} ,
96
+ method:: Symbol = :LU ;
97
+ tolerance:: Float64 = 1e-12 ,
98
+ maxbonddim= typemax (Int)
99
+ ) where {V, N}
100
+ for ell in 1 : length (tt)- 1
101
+ shapel = size (tt. sitetensors[ell])
102
+ left, right, newbonddim = _factorize (
103
+ reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]),
104
+ method; tolerance, maxbonddim
105
+ )
106
+ tt. sitetensors[ell] = reshape (left, shapel[1 : end - 1 ]. .. , newbonddim)
107
+ shaper = size (tt. sitetensors[ell+ 1 ])
108
+ nexttensor = right * reshape (tt. sitetensors[ell+ 1 ], shaper[1 ], prod (shaper[2 : end ]))
109
+ tt. sitetensors[ell+ 1 ] = reshape (nexttensor, newbonddim, shaper[2 : end ]. .. )
110
+ end
111
+
112
+ for ell in length (tt): - 1 : 2
113
+ shaper = size (tt. sitetensors[ell])
114
+ left, right, newbonddim = _factorize (
115
+ reshape (tt. sitetensors[ell], shaper[1 ], prod (shaper[2 : end ])),
116
+ method; tolerance, maxbonddim
117
+ )
118
+ tt. sitetensors[ell] = reshape (right, newbonddim, shaper[2 : end ]. .. )
119
+ shapel = size (tt. sitetensors[ell- 1 ])
120
+ nexttensor = reshape (tt. sitetensors[ell- 1 ], prod (shapel[1 : end - 1 ]), shapel[end ]) * left
121
+ tt. sitetensors[ell- 1 ] = reshape (nexttensor, shapel[1 : end - 1 ]. .. , newbonddim)
122
+ end
123
+
124
+ nothing
125
+ end
126
+
59
127
60
128
"""
61
129
Fitting data with a TensorTrain object.
@@ -83,7 +151,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
83
151
reshape (
84
152
x[obj. offsets[n]+ 1 : obj. offsets[n+ 1 ]],
85
153
size (obj. tt[n])
86
- )
154
+ )
87
155
for n in 1 : length (obj. tt)
88
156
]
89
157
end
94
162
95
163
function (obj:: TensorTrainFit{ValueType} )(x:: Vector{ValueType} ) where {ValueType}
96
164
tensors = to_tensors (obj, x)
97
- return sum ((abs2 (_evaluate (tensors, indexset) - obj. values[i]) for (i, indexset) in enumerate (obj. indexsets)))
165
+ return sum ((abs2 (_evaluate (tensors, indexset) - obj. values[i]) for (i, indexset) in enumerate (obj. indexsets)))
98
166
end
0 commit comments