@@ -52,13 +52,38 @@ function TensorTrain(tci::AbstractTensorTrain{V})::TensorTrain{V,3} where {V}
52
52
return TensorTrain {V,3} (sitetensors (tci))
53
53
end
54
54
55
+ """
56
+ function TensorTrain{N}(tci::AbstractTensorTrain{V}) where {V,N}
57
+
58
+ Convert a tensor-train-like object into a tensor train.
59
+
60
+ Arguments:
61
+ - `tt::AbstractTensorTrain{V}`: a tensor-train-like object.
62
+ - `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
63
+ of `localdims` should be an array-like object of `N-2` integers.
64
+ """
65
+ function TensorTrain {V,N} (tt:: AbstractTensorTrain{V} , localdims):: TensorTrain{V,N} where {V,N}
66
+ for d in localdims
67
+ length (d) == N - 2 || error (" Each element of localdims be a list of N-2 integers." )
68
+ end
69
+ for n in 1 : length (tt)
70
+ prod (size (tt[n])[2 : end - 1 ]) == prod (localdims[n]) || error (" The local dimensions at n=$n must match the tensor sizes." )
71
+ end
72
+ return TensorTrain {V,N} (
73
+ [reshape (t, size (t, 1 ), localdims[n]. .. , size (t)[end ]) for (n, t) in enumerate (sitetensors (tt))])
74
+ end
75
+
76
+ function TensorTrain {N} (tt:: AbstractTensorTrain{V} , localdims):: TensorTrain{V,N} where {V,N}
77
+ return TensorTrain {V,N} (tt, localdims)
78
+ end
79
+
55
80
function tensortrain (tci)
56
81
return TensorTrain (tci)
57
82
end
58
83
59
84
function _factorize (
60
85
A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int
61
- ):: Tuple{Matrix{V}, Matrix{V}, Int} where {V}
86
+ ):: Tuple{Matrix{V},Matrix{V},Int} where {V}
62
87
if method === :LU
63
88
factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
64
89
return left (factorization), right (factorization), npivots (factorization)
92
117
Compress the tensor train `tt` using `LU`, `CI` or `SVD` decompositions.
93
118
"""
94
119
function compress! (
95
- tt:: TensorTrain{V, N} ,
120
+ tt:: TensorTrain{V,N} ,
96
121
method:: Symbol = :LU ;
97
122
tolerance:: Float64 = 1e-12 ,
98
123
maxbonddim:: Int = typemax (Int)
99
- ) where {V, N}
124
+ ) where {V,N}
100
125
for ell in 1 : length (tt)- 1
101
126
shapel = size (tt. sitetensors[ell])
102
127
left, right, newbonddim = _factorize (
@@ -125,48 +150,48 @@ function compress!(
125
150
end
126
151
127
152
128
- function multiply! (tt:: TensorTrain{V, N} , a) where {V, N}
153
+ function multiply! (tt:: TensorTrain{V,N} , a) where {V,N}
129
154
tt. sitetensors[end ] .= tt. sitetensors[end ] .* a
130
155
nothing
131
156
end
132
157
133
- function multiply! (a, tt:: TensorTrain{V, N} ) where {V, N}
158
+ function multiply! (a, tt:: TensorTrain{V,N} ) where {V,N}
134
159
tt. sitetensors[end ] .= a .* tt. sitetensors[end ]
135
160
nothing
136
161
end
137
162
138
- function multiply (tt:: TensorTrain{V, N} , a):: TensorTrain{V, N} where {V, N}
163
+ function multiply (tt:: TensorTrain{V,N} , a):: TensorTrain{V,N} where {V,N}
139
164
tt2 = deepcopy (tt)
140
165
multiply! (tt2, a)
141
166
return tt2
142
167
end
143
168
144
- function multiply (a, tt:: TensorTrain{V, N} ):: TensorTrain{V, N} where {V, N}
169
+ function multiply (a, tt:: TensorTrain{V,N} ):: TensorTrain{V,N} where {V,N}
145
170
tt2 = deepcopy (tt)
146
171
multiply! (a, tt2)
147
172
return tt2
148
173
end
149
174
150
- function Base.:* (tt:: TensorTrain{V, N} , a):: TensorTrain{V, N} where {V, N}
175
+ function Base.:* (tt:: TensorTrain{V,N} , a):: TensorTrain{V,N} where {V,N}
151
176
return multiply (tt, a)
152
177
end
153
178
154
- function Base.:* (a, tt:: TensorTrain{V, N} ):: TensorTrain{V, N} where {V, N}
179
+ function Base.:* (a, tt:: TensorTrain{V,N} ):: TensorTrain{V,N} where {V,N}
155
180
return multiply (a, tt)
156
181
end
157
182
158
- function divide! (tt:: TensorTrain{V, N} , a) where {V, N}
183
+ function divide! (tt:: TensorTrain{V,N} , a) where {V,N}
159
184
tt. sitetensors[end ] .= tt. sitetensors[end ] ./ a
160
185
nothing
161
186
end
162
187
163
- function divide (tt:: TensorTrain{V, N} , a) where {V, N}
188
+ function divide (tt:: TensorTrain{V,N} , a) where {V,N}
164
189
tt2 = deepcopy (tt)
165
190
divide! (tt2, a)
166
191
return tt2
167
192
end
168
193
169
- function Base.:/ (tt:: TensorTrain{V, N} , a) where {V, N}
194
+ function Base.:/ (tt:: TensorTrain{V,N} , a) where {V,N}
170
195
return divide (tt, a)
171
196
end
172
197
@@ -201,7 +226,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
201
226
]
202
227
end
203
228
204
- function _evaluate (tt:: Vector{Array{V, 3}} , indexset) where {V}
229
+ function _evaluate (tt:: Vector{Array{V,3}} , indexset) where {V}
205
230
only (prod (T[:, i, :] for (T, i) in zip (tt, indexset)))
206
231
end
207
232
0 commit comments