@@ -56,40 +56,61 @@ function tensortrain(tci)
56
56
return TensorTrain (tci)
57
57
end
58
58
59
+ function _factorize (
60
+ A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int
61
+ ) where {V}
62
+ if method === :LU
63
+ factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
64
+ return left (factorization), right (factorization), npivots (factorization)
65
+ elseif method === :CI
66
+ factorization = MatrixLUCI (A, abstol= tolerance, maxrank= maxbonddim)
67
+ return left (factorization), right (factorization), npivots (factorization)
68
+ elseif method === :SVD
69
+ factorization = LinearAlgebra. svd (A)
70
+ trunci = min (
71
+ replacenothing (findlast (> (tolerance), factorization. S), 1 ),
72
+ maxbonddim
73
+ )
74
+ return (
75
+ factorization. U[:, 1 : trunci],
76
+ Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
77
+ trunci
78
+ )
79
+ else
80
+ error (" Not implemented yet." )
81
+ end
82
+ end
83
+
59
84
function compress! (
60
- tt:: AbstractTensorTrain{V } ,
85
+ tt:: TensorTrain{V, N } ,
61
86
method:: Symbol = :LU ;
62
87
tolerance:: Float64 = 1e-12 ,
63
88
maxbonddim= typemax (Int)
64
- ) where {V}
65
- function factorize (A:: Matrix{V} )
66
- if method === :LU
67
- factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
68
- return left (factorization), right (factorization), npivots (factorization)
69
- elseif method === :CI
70
- factorization = MatrixLUCI (A, abstol= tolerance, maxrank= maxbonddim)
71
- return left (factorization), right (factorization), npivots (factorization)
72
- elseif method === :SVD
73
- factorization = LinearAlgebra. svd (A)
74
- trunci = min (findlast (> (tolerance), factorization. S), maxbonddim)
75
- return (
76
- factorization. U[:, 1 : trunci],
77
- Diagonal (factorization. S[1 : trunci]) * factorization. Vt[1 : trunci, :],
78
- trunci
79
- )
80
- else
81
- error (" Not implemented yet." )
82
- end
83
- end
84
-
89
+ ) where {V, N}
85
90
for ell in 1 : length (tt)- 1
86
91
shapel = size (tt. sitetensors[ell])
87
- left, right, newbonddim = factorize (reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]))
92
+ left, right, newbonddim = _factorize (
93
+ reshape (tt. sitetensors[ell], prod (shapel[1 : end - 1 ]), shapel[end ]),
94
+ method; tolerance, maxbonddim
95
+ )
88
96
tt. sitetensors[ell] = reshape (left, shapel[1 : end - 1 ]. .. , newbonddim)
89
97
shaper = size (tt. sitetensors[ell+ 1 ])
90
98
nexttensor = right * reshape (tt. sitetensors[ell+ 1 ], shaper[1 ], prod (shaper[2 : end ]))
91
99
tt. sitetensors[ell+ 1 ] = reshape (nexttensor, newbonddim, shaper[2 : end ]. .. )
92
100
end
101
+
102
+ for ell in length (tt): - 1 : 2
103
+ shaper = size (tt. sitetensors[ell])
104
+ left, right, newbonddim = _factorize (
105
+ reshape (tt. sitetensors[ell], shaper[1 ], prod (shaper[2 : end ])),
106
+ method; tolerance, maxbonddim
107
+ )
108
+ tt. sitetensors[ell] = reshape (right, newbonddim, shaper[2 : end ]. .. )
109
+ shapel = size (tt. sitetensors[ell- 1 ])
110
+ nexttensor = reshape (tt. sitetensors[ell- 1 ], prod (shapel[1 : end - 1 ]), shapel[end ]) * left
111
+ tt. sitetensors[ell- 1 ] = reshape (nexttensor, shapel[1 : end - 1 ]. .. , newbonddim)
112
+ end
113
+
93
114
nothing
94
115
end
95
116
0 commit comments