@@ -62,9 +62,9 @@ Arguments:
62
62
- `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
63
63
of `localdims` should be an array-like object of `N-2` integers.
64
64
"""
65
- function TensorTrain {N} (tt:: AbstractTensorTrain{V} , localdims):: TensorTrain{V,N} where {V,N}
65
+ function TensorTrain {V, N} (tt:: AbstractTensorTrain{V} , localdims):: TensorTrain{V,N} where {V,N}
66
66
for d in localdims
67
- length (d) == N- 2 || error (" Each element of localdims be a list of N-2 integers." )
67
+ length (d) == N - 2 || error (" Each element of localdims be a list of N-2 integers." )
68
68
end
69
69
for n in 1 : length (tt)
70
70
prod (size (tt[n])[2 : end - 1 ]) == prod (localdims[n]) || error (" The local dimensions at n=$n must match the tensor sizes." )
@@ -73,13 +73,17 @@ function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N}
73
73
[reshape (t, size (t, 1 ), localdims[n]. .. , size (t)[end ]) for (n, t) in enumerate (sitetensors (tt))])
74
74
end
75
75
76
+ function TensorTrain {N} (tt:: AbstractTensorTrain{V} , localdims):: TensorTrain{V,N} where {V,N}
77
+ return TensorTrain {V,N} (tt, localdims)
78
+ end
79
+
76
80
function tensortrain (tci)
77
81
return TensorTrain (tci)
78
82
end
79
83
80
84
function _factorize (
81
85
A:: Matrix{V} , method:: Symbol ; tolerance:: Float64 , maxbonddim:: Int
82
- ):: Tuple{Matrix{V}, Matrix{V}, Int} where {V}
86
+ ):: Tuple{Matrix{V},Matrix{V},Int} where {V}
83
87
if method === :LU
84
88
factorization = rrlu (A, abstol= tolerance, maxrank= maxbonddim)
85
89
return left (factorization), right (factorization), npivots (factorization)
@@ -113,11 +117,11 @@ end
113
117
Compress the tensor train `tt` using `LU`, `CI` or `SVD` decompositions.
114
118
"""
115
119
function compress! (
116
- tt:: TensorTrain{V, N} ,
120
+ tt:: TensorTrain{V,N} ,
117
121
method:: Symbol = :LU ;
118
122
tolerance:: Float64 = 1e-12 ,
119
123
maxbonddim:: Int = typemax (Int)
120
- ) where {V, N}
124
+ ) where {V,N}
121
125
for ell in 1 : length (tt)- 1
122
126
shapel = size (tt. sitetensors[ell])
123
127
left, right, newbonddim = _factorize (
@@ -146,48 +150,48 @@ function compress!(
146
150
end
147
151
148
152
149
- function multiply! (tt:: TensorTrain{V, N} , a) where {V, N}
153
+ function multiply! (tt:: TensorTrain{V,N} , a) where {V,N}
150
154
tt. sitetensors[end ] .= tt. sitetensors[end ] .* a
151
155
nothing
152
156
end
153
157
154
- function multiply! (a, tt:: TensorTrain{V, N} ) where {V, N}
158
+ function multiply! (a, tt:: TensorTrain{V,N} ) where {V,N}
155
159
tt. sitetensors[end ] .= a .* tt. sitetensors[end ]
156
160
nothing
157
161
end
158
162
159
- function multiply (tt:: TensorTrain{V, N} , a):: TensorTrain{V, N} where {V, N}
163
+ function multiply (tt:: TensorTrain{V,N} , a):: TensorTrain{V,N} where {V,N}
160
164
tt2 = deepcopy (tt)
161
165
multiply! (tt2, a)
162
166
return tt2
163
167
end
164
168
165
- function multiply (a, tt:: TensorTrain{V, N} ):: TensorTrain{V, N} where {V, N}
169
+ function multiply (a, tt:: TensorTrain{V,N} ):: TensorTrain{V,N} where {V,N}
166
170
tt2 = deepcopy (tt)
167
171
multiply! (a, tt2)
168
172
return tt2
169
173
end
170
174
171
- function Base.:* (tt:: TensorTrain{V, N} , a):: TensorTrain{V, N} where {V, N}
175
+ function Base.:* (tt:: TensorTrain{V,N} , a):: TensorTrain{V,N} where {V,N}
172
176
return multiply (tt, a)
173
177
end
174
178
175
- function Base.:* (a, tt:: TensorTrain{V, N} ):: TensorTrain{V, N} where {V, N}
179
+ function Base.:* (a, tt:: TensorTrain{V,N} ):: TensorTrain{V,N} where {V,N}
176
180
return multiply (a, tt)
177
181
end
178
182
179
- function divide! (tt:: TensorTrain{V, N} , a) where {V, N}
183
+ function divide! (tt:: TensorTrain{V,N} , a) where {V,N}
180
184
tt. sitetensors[end ] .= tt. sitetensors[end ] ./ a
181
185
nothing
182
186
end
183
187
184
- function divide (tt:: TensorTrain{V, N} , a) where {V, N}
188
+ function divide (tt:: TensorTrain{V,N} , a) where {V,N}
185
189
tt2 = deepcopy (tt)
186
190
divide! (tt2, a)
187
191
return tt2
188
192
end
189
193
190
- function Base.:/ (tt:: TensorTrain{V, N} , a) where {V, N}
194
+ function Base.:/ (tt:: TensorTrain{V,N} , a) where {V,N}
191
195
return divide (tt, a)
192
196
end
193
197
@@ -222,7 +226,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
222
226
]
223
227
end
224
228
225
- function _evaluate (tt:: Vector{Array{V, 3}} , indexset) where {V}
229
+ function _evaluate (tt:: Vector{Array{V,3}} , indexset) where {V}
226
230
only (prod (T[:, i, :] for (T, i) in zip (tt, indexset)))
227
231
end
228
232
0 commit comments