Skip to content

Commit a463259

Browse files
committed
added multiply-add and subtract functionality
1 parent 7f801a9 commit a463259

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

src/abstracttensortrain.jl

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,11 @@ function sum(tt::AbstractTensorTrain{V}) where {V}
172172
return only(v)
173173
end
174174

175-
function _addtttensor(A::Array{V}, B::Array{V}; lefttensor=false, righttensor=false) where {V}
175+
function _addtttensor(
176+
A::Array{V}, B::Array{V};
177+
factorA=one(V), factorB=one(V),
178+
lefttensor=false, righttensor=false
179+
) where {V}
176180
if ndims(A) != ndims(B)
177181
throw(DimensionMismatch("Elementwise addition only works if both tensors have the same indices, but A and B have different numbers ($(ndims(A)) and $(ndims(B))) of indices."))
178182
end
@@ -181,20 +185,33 @@ function _addtttensor(A::Array{V}, B::Array{V}; lefttensor=false, righttensor=fa
181185
offset3 = righttensor ? 0 : size(A, nd)
182186
localindices = fill(Colon(), nd - 2)
183187
C = zeros(V, offset1 + size(B, 1), size(A)[2:nd-1]..., offset3 + size(B, nd))
184-
C[1:size(A, 1), localindices..., 1:size(A, nd)] = A
185-
C[offset1+1:end, localindices..., offset3+1:end] = B
188+
C[1:size(A, 1), localindices..., 1:size(A, nd)] = factorA * A
189+
C[offset1+1:end, localindices..., offset3+1:end] = factorB * B
186190
return C
187191
end
188192

189193
@doc raw"""
190-
function add(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
194+
function add(
195+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
196+
factorlhs=one(V), factorrhs=one(V),
197+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
198+
) where {V}
199+
200+
Addition of two tensor trains. If `C = add(A, B)`, then `C(v) ≈ A(v) + B(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``.
191201
192-
Addition of two tensor trains. If `c = add(a, b)`, then `c(v) ≈ a(v) + b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. In many cases, it is advisable to recompress/truncate the resulting tensor train afterwards.
202+
Arguments:
203+
- `lhs`, `rhs`: Tensor trains to be added.
204+
- `factorlhs`, `factorrhs`: Factors to multiply each tensor train by before addition.
205+
- `tolerance`, `maxbonddim`: Parameters to be used for the recompression step.
206+
207+
Returns:
208+
A new `TensorTrain` representing the function `factorlhs * lhs(v) + factorrhs * rhs(v)`.
193209
194210
See also: [`+`](@ref)
195211
"""
196212
function add(
197213
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
214+
factorlhs=one(V), factorrhs=one(V),
198215
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
199216
) where {V}
200217
if length(lhs) != length(rhs)
@@ -203,7 +220,13 @@ function add(
203220
L = length(lhs)
204221
tt = tensortrain(
205222
[
206-
_addtttensor(lhs[ell], rhs[ell]; lefttensor=(ell==1), righttensor=(ell==L))
223+
_addtttensor(
224+
lhs[ell], rhs[ell];
225+
factorA=((ell == L) ? factorlhs : one(V)),
226+
factorB=((ell == L) ? factorrhs : one(V)),
227+
lefttensor=(ell==1),
228+
righttensor=(ell==L)
229+
)
207230
for ell in 1:L
208231
]
209232
)
@@ -212,12 +235,34 @@ function add(
212235
end
213236

214237
@doc raw"""
215-
function (+)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
238+
function subtract(
239+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
240+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
241+
)
242+
243+
Subtract two tensor trains `lhs` and `rhs`. See [`add`](@ref).
244+
"""
245+
function subtract(
246+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
247+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
248+
) where {V}
249+
return add(lhs, rhs; factorrhs=-1 * one(V), tolerance, maxbonddim)
250+
end
216251

217-
Addition of two tensor trains. If `c = a + b`, then `c(v) ≈ a(v) + b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. In many cases, it is advisable to recompress/truncate the resulting tensor train afterwards.
252+
@doc raw"""
253+
function (+)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
218254
219-
See also: [`add`](@ref)
255+
Addition of two tensor trains. If `c = a + b`, then `c(v) ≈ a(v) + b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. Can be combined with automatic recompression by calling [`add`](@ref).
220256
"""
221257
function Base.:+(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
222258
return add(lhs, rhs)
223259
end
260+
261+
@doc raw"""
262+
function (-)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
263+
264+
Subtraction of two tensor trains. If `c = a - b`, then `c(v) ≈ a(v) - b(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``. Can be combined with automatic recompression by calling [`subtract`](@ref) (see documentation for [`add`](@ref)).
265+
"""
266+
function Base.:-(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
267+
return subtract(lhs, rhs)
268+
end

test/test_tensortrain.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ end
140140
tt1div = tt1mul / 3.2
141141
@test tt1div.(indices) tt1.(indices) ./ 2.0
142142

143+
tt1sub = tt1 - tt1div
144+
@test tt1sub.(indices) tt1.(indices) ./ 2.0
145+
143146
ttshort = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L-1])
144147
@test_throws DimensionMismatch TCI.add(tt1, ttshort)
145148

0 commit comments

Comments
 (0)