Skip to content

Commit 8a77863

Browse files
author
Marc Ritter
committed
Merge branch 'tensor_train_arithmetics' into 'main'
Tensor train arithmetics See merge request tensors4fields/TensorCrossInterpolation.jl!80
2 parents 64969ec + 990b1c8 commit 8a77863

File tree

3 files changed

+118
-14
lines changed

3 files changed

+118
-14
lines changed

src/abstracttensortrain.jl

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,11 @@ function sum(tt::AbstractTensorTrain{V}) where {V}
172172
return only(v)
173173
end
174174

175-
function _addtttensor(A::Array{V}, B::Array{V}; lefttensor=false, righttensor=false) where {V}
175+
function _addtttensor(
176+
A::Array{V}, B::Array{V};
177+
factorA=one(V), factorB=one(V),
178+
lefttensor=false, righttensor=false
179+
) where {V}
176180
if ndims(A) != ndims(B)
177181
throw(DimensionMismatch("Elementwise addition only works if both tensors have the same indices, but A and B have different numbers ($(ndims(A)) and $(ndims(B))) of indices."))
178182
end
@@ -181,38 +185,84 @@ function _addtttensor(A::Array{V}, B::Array{V}; lefttensor=false, righttensor=fa
181185
offset3 = righttensor ? 0 : size(A, nd)
182186
localindices = fill(Colon(), nd - 2)
183187
C = zeros(V, offset1 + size(B, 1), size(A)[2:nd-1]..., offset3 + size(B, nd))
184-
C[1:size(A, 1), localindices..., 1:size(A, nd)] = A
185-
C[offset1+1:end, localindices..., offset3+1:end] = B
188+
C[1:size(A, 1), localindices..., 1:size(A, nd)] = factorA * A
189+
C[offset1+1:end, localindices..., offset3+1:end] = factorB * B
186190
return C
187191
end
188192

189193
@doc raw"""
190-
function add(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
194+
function add(
195+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
196+
factorlhs=one(V), factorrhs=one(V),
197+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
198+
) where {V}
191199
192-
Addition of two tensor trains. If `c = add(a, b)`, then `c(v) ≈ a(v) + b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. In many cases, it is advisable to recompress/truncate the resulting tensor train afterwards.
200+
Addition of two tensor trains. If `C = add(A, B)`, then `C(v) ≈ A(v) + B(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``.
201+
202+
Arguments:
203+
- `lhs`, `rhs`: Tensor trains to be added.
204+
- `factorlhs`, `factorrhs`: Factors to multiply each tensor train by before addition.
205+
- `tolerance`, `maxbonddim`: Parameters to be used for the recompression step.
206+
207+
Returns:
208+
A new `TensorTrain` representing the function `factorlhs * lhs(v) + factorrhs * rhs(v)`.
193209
194210
See also: [`+`](@ref)
195211
"""
196-
function add(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
212+
function add(
213+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
214+
factorlhs=one(V), factorrhs=one(V),
215+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
216+
) where {V}
197217
if length(lhs) != length(rhs)
198218
throw(DimensionMismatch("Two tensor trains with different length ($(length(lhs)) and $(length(rhs))) cannot be added elementwise."))
199219
end
200220
L = length(lhs)
201-
return tensortrain(
221+
tt = tensortrain(
202222
[
203-
_addtttensor(lhs[ell], rhs[ell]; lefttensor=(ell==1), righttensor=(ell==L))
223+
_addtttensor(
224+
lhs[ell], rhs[ell];
225+
factorA=((ell == L) ? factorlhs : one(V)),
226+
factorB=((ell == L) ? factorrhs : one(V)),
227+
lefttensor=(ell==1),
228+
righttensor=(ell==L)
229+
)
204230
for ell in 1:L
205231
]
206232
)
233+
compress!(tt, :SVD; tolerance, maxbonddim)
234+
return tt
207235
end
208236

209237
@doc raw"""
210-
function (+)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
238+
function subtract(
239+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
240+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
241+
) where {V}
242+
243+
Subtract two tensor trains `lhs` and `rhs`. See [`add`](@ref).
244+
"""
245+
function subtract(
246+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
247+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
248+
) where {V}
249+
return add(lhs, rhs; factorrhs=-1 * one(V), tolerance, maxbonddim)
250+
end
211251

212-
Addition of two tensor trains. If `c = a + b`, then `c(v) ≈ a(v) + b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. In many cases, it is advisable to recompress/truncate the resulting tensor train afterwards.
252+
@doc raw"""
253+
function (+)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
213254
214-
See also: [`add`](@ref)
255+
Addition of two tensor trains. If `c = a + b`, then `c(v) ≈ a(v) + b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. Can be combined with automatic recompression by calling [`add`](@ref).
215256
"""
216-
function (+)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
257+
function Base.:+(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
217258
return add(lhs, rhs)
218259
end
260+
261+
@doc raw"""
262+
function (-)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
263+
264+
Subtraction of two tensor trains. If `c = a - b`, then `c(v) ≈ a(v) - b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. Can be combined with automatic recompression by calling [`subtract`](@ref) (see documentation for [`add`](@ref)).
265+
"""
266+
function Base.:-(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
267+
return subtract(lhs, rhs)
268+
end

src/tensortrain.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function compress!(
9595
tt::TensorTrain{V, N},
9696
method::Symbol=:LU;
9797
tolerance::Float64=1e-12,
98-
maxbonddim=typemax(Int)
98+
maxbonddim::Int=typemax(Int)
9999
) where {V, N}
100100
for ell in 1:length(tt)-1
101101
shapel = size(tt.sitetensors[ell])
@@ -125,6 +125,51 @@ function compress!(
125125
end
126126

127127

128+
function multiply!(tt::TensorTrain{V, N}, a) where {V, N}
129+
tt.sitetensors[end] .= tt.sitetensors[end] .* a
130+
nothing
131+
end
132+
133+
function multiply!(a, tt::TensorTrain{V, N}) where {V, N}
134+
tt.sitetensors[end] .= a .* tt.sitetensors[end]
135+
nothing
136+
end
137+
138+
function multiply(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
139+
tt2 = deepcopy(tt)
140+
multiply!(tt2, a)
141+
return tt2
142+
end
143+
144+
function multiply(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
145+
tt2 = deepcopy(tt)
146+
multiply!(a, tt2)
147+
return tt2
148+
end
149+
150+
function Base.:*(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
151+
return multiply(tt, a)
152+
end
153+
154+
function Base.:*(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
155+
return multiply(a, tt)
156+
end
157+
158+
function divide!(tt::TensorTrain{V, N}, a) where {V, N}
159+
tt.sitetensors[end] .= tt.sitetensors[end] ./ a
160+
nothing
161+
end
162+
163+
function divide(tt::TensorTrain{V, N}, a) where {V, N}
164+
tt2 = deepcopy(tt)
165+
divide!(tt2, a)
166+
return tt2
167+
end
168+
169+
function Base.:/(tt::TensorTrain{V, N}, a) where {V, N}
170+
return divide(tt, a)
171+
end
172+
128173
"""
129174
Fitting data with a TensorTrain object.
130175
This may be useful when the interpolated function is noisy.

test/test_tensortrain.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ end
119119
@test [TCI.evaluate(ttopt, idx) for (idx, v) in zip(indexsets, values)] ≈ values
120120
end
121121

122-
@testset "tensor train addition" for T in [Float64, ComplexF64]
122+
@testset "tensor train addition and multiplication" for T in [Float64, ComplexF64]
123123
Random.seed!(10)
124124
localdims = [2, 2, 2]
125125
linkdims = [1, 2, 3, 1]
@@ -134,6 +134,15 @@ end
134134
ttadd2 = tt1 + tt2
135135
@test ttadd2.(indices) ≈ [tt1(v) + tt2(v) for v in indices]
136136

137+
tt1mul = 1.6 * tt1
138+
@test tt1mul.(indices) ≈ 1.6 .* tt1.(indices)
139+
140+
tt1div = tt1mul / 3.2
141+
@test tt1div.(indices) ≈ tt1.(indices) ./ 2.0
142+
143+
tt1sub = tt1 - tt1div
144+
@test tt1sub.(indices) ≈ tt1.(indices) ./ 2.0
145+
137146
ttshort = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L-1])
138147
@test_throws DimensionMismatch TCI.add(tt1, ttshort)
139148

0 commit comments

Comments
 (0)