Skip to content

Commit 91ff638

Browse files
committed
added SVD recompression step to addition
1 parent f9d406a commit 91ff638

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/abstracttensortrain.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,17 +193,22 @@ Addition of two tensor trains. If `c = add(a, b)`, then `c(v) ≈ a(v) + b(v)` a
193193
194194
See also: [`+`](@ref)
195195
"""
196-
function add(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
196+
function add(
197+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
198+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
199+
) where {V}
197200
if length(lhs) != length(rhs)
198201
throw(DimensionMismatch("Two tensor trains with different length ($(length(lhs)) and $(length(rhs))) cannot be added elementwise."))
199202
end
200203
L = length(lhs)
201-
return tensortrain(
204+
tt = tensortrain(
202205
[
203206
_addtttensor(lhs[ell], rhs[ell]; lefttensor=(ell==1), righttensor=(ell==L))
204207
for ell in 1:L
205208
]
206209
)
210+
compress!(tt, :SVD; tolerance, maxbonddim)
211+
return tt
207212
end
208213

209214
@doc raw"""

src/tensortrain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function compress!(
9595
tt::TensorTrain{V, N},
9696
method::Symbol=:LU;
9797
tolerance::Float64=1e-12,
98-
maxbonddim=typemax(Int)
98+
maxbonddim::Int=typemax(Int)
9999
) where {V, N}
100100
for ell in 1:length(tt)-1
101101
shapel = size(tt.sitetensors[ell])

0 commit comments

Comments
 (0)