@@ -69,3 +69,108 @@ function TensorCI2{ValueType}(tci1::TensorCI1{ValueType}) where {ValueType}
69
69
tci2. maxsamplevalue = tci1. maxsamplevalue
70
70
return tci2
71
71
end
72
+
73
+ function sweep1sitegetindices! (
74
+ tt:: TensorTrain{ValueType,3} , forwardsweep:: Bool ,
75
+ spectatorindices:: Vector{Vector{MultiIndex}} = Vector{MultiIndex}[];
76
+ maxbonddim= typemax (Int), tolerance= 0.0
77
+ ) where {ValueType}
78
+ indexset = Vector{MultiIndex}[MultiIndex[[]]]
79
+ pivoterrorsarray = zeros (rank (tt) + 1 )
80
+
81
+ function groupindices (T:: AbstractArray , next:: Bool )
82
+ shape = size (T)
83
+ if forwardsweep != next
84
+ reshape (T, prod (shape[1 : end - 1 ]), shape[end ])
85
+ else
86
+ reshape (T, shape[1 ], prod (shape[2 : end ]))
87
+ end
88
+ end
89
+
90
+ function splitindices (T:: AbstractArray , shape, newbonddim, next:: Bool )
91
+ if forwardsweep != next
92
+ newshape = (shape[1 : end - 1 ]. .. , newbonddim)
93
+ else
94
+ newshape = (newbonddim, shape[2 : end ]. .. )
95
+ end
96
+ reshape (T, newshape)
97
+ end
98
+
99
+ L = length (tt)
100
+ for i in 1 : L- 1
101
+ ell = forwardsweep ? i : L - i + 1
102
+ ellnext = forwardsweep ? i + 1 : L - i
103
+ shape = size (tt. sitetensors[ell])
104
+ shapenext = size (tt. sitetensors[ellnext])
105
+
106
+ luci = MatrixLUCI (
107
+ groupindices (tt. sitetensors[ell], false ), leftorthogonal= forwardsweep,
108
+ abstol= tolerance, maxrank= maxbonddim
109
+ )
110
+
111
+ if forwardsweep
112
+ push! (indexset, kronecker (last (indexset), shape[2 ])[rowindices (luci)])
113
+ if ! isempty (spectatorindices)
114
+ spectatorindices[ell] = spectatorindices[ell][colindices (luci)]
115
+ end
116
+ else
117
+ push! (indexset, kronecker (shape[2 ], last (indexset))[colindices (luci)])
118
+ if ! isempty (spectatorindices)
119
+ spectatorindices[ell] = spectatorindices[ell][rowindices (luci)]
120
+ end
121
+ end
122
+
123
+
124
+ tt. sitetensors[ell] = splitindices (
125
+ forwardsweep ? left (luci) : right (luci),
126
+ shape, npivots (luci), false
127
+ )
128
+
129
+ nexttensor = (
130
+ forwardsweep
131
+ ? right (luci) * groupindices (tt. sitetensors[ellnext], true )
132
+ : groupindices (tt. sitetensors[ellnext], true ) * left (luci)
133
+ )
134
+
135
+ tt. sitetensors[ellnext] = splitindices (nexttensor, shapenext, npivots (luci), true )
136
+ pivoterrorsarray[1 : npivots (luci) + 1 ] = max .(pivoterrorsarray[1 : npivots (luci) + 1 ], pivoterrors (luci))
137
+ end
138
+
139
+ if forwardsweep
140
+ return indexset, pivoterrorsarray
141
+ else
142
+ return reverse (indexset), pivoterrorsarray
143
+ end
144
+ end
145
+
146
+ function TensorCI2 {ValueType} (
147
+ tt:: TensorTrain{ValueType,3} ; tolerance= 1e-12 , maxbonddim= typemax (Int), maxiter= 3
148
+ ) where {ValueType}
149
+ local pivoterrors:: Vector{Float64}
150
+
151
+ Iset, = sweep1sitegetindices! (tt, true ; maxbonddim, tolerance)
152
+ Jset, pivoterrors = sweep1sitegetindices! (tt, false ; maxbonddim, tolerance)
153
+
154
+ for iter in 3 : maxiter
155
+ if isodd (iter)
156
+ Isetnew, pivoterrors = sweep1sitegetindices! (tt, true , Jset)
157
+ if Isetnew == Iset
158
+ break
159
+ end
160
+ else
161
+ Jsetnew, pivoterrors = sweep1sitegetindices! (tt, false , Iset)
162
+ if Jsetnew == Jset
163
+ break
164
+ end
165
+ end
166
+ end
167
+
168
+ tci2 = TensorCI2 {ValueType} (first .(sitedims (tt)))
169
+ tci2. Iset = Iset
170
+ tci2. Jset = Jset
171
+ tci2. sitetensors = sitetensors (tt)
172
+ tci2. pivoterrors = pivoterrors
173
+ tci2. maxsamplevalue = maximum (maximum .(abs, tci2. sitetensors))
174
+
175
+ return tci2
176
+ end
0 commit comments