@@ -82,25 +82,50 @@ function tensortrain(tci)
82
82
end
83
83
84
84
function _factorize (
85
- A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int
85
+ A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int , leftorthogonal :: Bool = false , normalizeerror = true
86
86
):: Tuple{Matrix{V},Matrix{V},Int} where {V}
87
+ reltol = 1e-14
88
+ abstol = 0.0
89
+ if normalizeerror
90
+ reltol = tolerance
91
+ else
92
+ abstol = tolerance
93
+ end
87
94
if method === :LU
88
- factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
95
+ factorization = rrlu (A, abstol= abstol, reltol = reltol, maxrank= maxbonddim, leftorthogonal = leftorthogonal )
89
96
return left (factorization), right (factorization), npivots (factorization)
90
97
elseif method === :CI
91
- factorization = MatrixLUCI (A, abstol= tolerance, maxrank= maxbonddim)
98
+ factorization = MatrixLUCI (A, abstol= abstol, reltol = reltol, maxrank= maxbonddim, leftorthogonal = leftorthogonal )
92
99
return left (factorization), right (factorization), npivots (factorization)
93
100
elseif method === :SVD
94
101
factorization = LinearAlgebra. svd (A)
102
+ err = [sum (factorization. S[n+ 1 : end ] .^ 2 ) for n in 1 : length (factorization. S)]
103
+ normalized_err = err ./ sum (factorization. S .^ 2 )
104
+
105
+ # @show normalized_err
106
+ # @show sum(factorization.S .^ 2)
107
+ # @show err
95
108
trunci = min (
96
- replacenothing (findlast (> (tolerance), factorization. S), 1 ),
109
+ replacenothing (findfirst (< (abstol^ 2 ), err), length (err)),
110
+ replacenothing (findfirst (< (reltol^ 2 ), normalized_err), length (normalized_err)),
97
111
maxbonddim
98
112
)
99
- return (
100
- factorization. U[:, 1 : trunci],
101
- Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
102
- trunci
103
- )
113
+ # @show findfirst(<(abstol^2), err)
114
+ # @show findfirst(<(reltol^2), normalized_err)
115
+ # @show trunci, length(err)
116
+ if leftorthogonal
117
+ return (
118
+ factorization. U[:, 1 : trunci],
119
+ Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
120
+ trunci
121
+ )
122
+ else
123
+ return (
124
+ factorization. U[:, 1 : trunci] * Diagonal (factorization. S[1 : trunci]),
125
+ factorization. Vt[1 : trunci, :],
126
+ trunci
127
+ )
128
+ end
104
129
else
105
130
error (" Not implemented yet." )
106
131
end
@@ -120,32 +145,41 @@ function compress!(
120
145
tt:: TensorTrain{V,N} ,
121
146
method:: Symbol = :LU ;
122
147
tolerance:: Float64 = 1e-12 ,
123
- maxbonddim:: Int = typemax (Int)
148
+ maxbonddim:: Int = typemax (Int),
149
+ normalizeerror:: Bool = true
124
150
) where {V,N}
151
+ # From left to right
125
152
for ell in 1 : length (tt)- 1
153
+ # println("ell=$ell")
126
154
shapel = size (tt. sitetensors[ell])
127
155
left, right, newbonddim = _factorize (
128
156
reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]),
129
- method; tolerance, maxbonddim
157
+ method; tolerance= 0.0 , maxbonddim= typemax (Int), leftorthogonal = true # no truncation
130
158
)
131
159
tt. sitetensors[ell] = reshape (left, shapel[1 : end - 1 ]. .. , newbonddim)
132
160
shaper = size (tt. sitetensors[ell+ 1 ])
133
161
nexttensor = right * reshape (tt. sitetensors[ell+ 1 ], shaper[1 ], prod (shaper[2 : end ]))
134
162
tt. sitetensors[ell+ 1 ] = reshape (nexttensor, newbonddim, shaper[2 : end ]. .. )
135
163
end
136
164
165
+ # From right to left
137
166
for ell in length (tt): - 1 : 2
138
167
shaper = size (tt. sitetensors[ell])
139
168
left, right, newbonddim = _factorize (
140
169
reshape (tt. sitetensors[ell], shaper[1 ], prod (shaper[2 : end ])),
141
- method; tolerance, maxbonddim
170
+ method; tolerance, maxbonddim, normalizeerror, leftorthogonal = false
142
171
)
143
172
tt. sitetensors[ell] = reshape (right, newbonddim, shaper[2 : end ]. .. )
144
173
shapel = size (tt. sitetensors[ell- 1 ])
145
174
nexttensor = reshape (tt. sitetensors[ell- 1 ], prod (shapel[1 : end - 1 ]), shapel[end ]) * left
146
175
tt. sitetensors[ell- 1 ] = reshape (nexttensor, shapel[1 : end - 1 ]. .. , newbonddim)
147
176
end
148
177
178
+ # println("")
179
+ # println("")
180
+ # println("")
181
+ # println("")
182
+ # println("")
149
183
nothing
150
184
end
151
185
@@ -201,6 +235,7 @@ function Base.reverse(tt::AbstractTensorTrain{V}) where {V}
201
235
]))
202
236
end
203
237
238
+
204
239
"""
205
240
Fitting data with a TensorTrain object.
206
241
This may be useful when the interpolated function is noisy.
0 commit comments