@@ -399,6 +399,31 @@ function contract_TCI(
399
399
)
400
400
end
401
401
402
+ """
403
+ function contract(
404
+ A::TensorTrain{ValueType,4},
405
+ B::TensorTrain{ValueType,4};
406
+ algorithm::Symbol=:TCI,
407
+ tolerance::Float64=1e-12,
408
+ maxbonddim::Int=typemax(Int),
409
+ f::Union{Nothing,Function}=nothing,
410
+ kwargs...
411
+ ) where {ValueType}
412
+
413
+ Contract two tensor trains `A` and `B`.
414
+
415
+ Currently, two implementations are available:
416
+ 1. `algorithm=:TCI` constructs a new TCI that fits the contraction of `A` and `B`.
417
+ 2. `algorithm=:naive` uses a naive tensor contraction and subsequent SVD recompression of the tensor train.
418
+
419
+ Arguments:
420
+ - `A` and `B` are the tensor trains to be contracted.
421
+ - `algorithm` chooses the algorithm used to evaluate the contraction.
422
+ - `tolerance` is the tolerance of the TCI or SVD recompression.
423
+ - `maxbonddim` sets the maximum bond dimension of the resulting tensor train.
424
+ - `f` is a function to be applied elementwise to the result. This option is only available with `algorithm=:TCI`.
425
+ - `kwargs...` are forwarded to [`crossinterpolate2`](@ref) if `algorithm=:TCI`.
426
+ """
402
427
function contract (
403
428
A:: TensorTrain{ValueType,4} ,
404
429
B:: TensorTrain{ValueType,4} ;
@@ -411,7 +436,7 @@ function contract(
411
436
if algorithm === :TCI
412
437
return contract_TCI (A, B; tolerance= tolerance, maxbonddim= maxbonddim, f= f, kwargs... )
413
438
elseif algorithm === :naive
414
- return contract_naive (A, B)
439
+ return contract_naive (A, B; tolerance = tolerance, maxbonddim = maxbonddim )
415
440
else
416
441
throw (ArgumentError (" Unknown algorithm $algorithm ." ))
417
442
end
0 commit comments