Skip to content

Commit 7966c2c

Browse files
authored
Merge pull request #13 from tensor4all/8-tto-tto-contraction-with-zip-up-algorithm
Implement contract_zipup using SVD or LU
2 parents f938889 + 32d6f40 commit 7966c2c

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

src/contraction.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,56 @@ function contract_TCI(
399399
)
400400
end
401401

402+
"""
403+
See SVD version:
404+
https://tensornetwork.org/mps/algorithms/zip_up_mpo/
405+
"""
406+
function contract_zipup(
407+
A::TensorTrain{ValueType,4},
408+
B::TensorTrain{ValueType,4};
409+
tolerance::Float64=1e-12,
410+
method::Symbol=:SVD, # :SVD, :LU
411+
maxbonddim::Int=typemax(Int)
412+
) where {ValueType}
413+
if length(A) != length(B)
414+
throw(ArgumentError("Cannot contract tensor trains with different length."))
415+
end
416+
R::Array{ValueType,3} = ones(ValueType, 1, 1, 1)
417+
418+
sitetensors = Vector{Array{ValueType,4}}(undef, length(A))
419+
for n in 1:length(A)
420+
# R: (link_ab, link_an, link_bn)
421+
# A[n]: (link_an, s_n, s_n', link_anp1)
422+
RA = _contract(R, A[n], (2,), (1,))
423+
424+
# RA[n]: (link_ab, link_bn, s_n, s_n' link_anp1)
425+
# B[n]: (link_bn, s_n', s_n'', link_bnp1)
426+
# C: (link_ab, s_n, link_anp1, s_n'', link_bnp1)
427+
# => (link_ab, s_n, s_n'', link_anp1, link_bnp1)
428+
C = permutedims(_contract(RA, B[n], (2, 4), (1, 2)), (1, 2, 4, 3, 5))
429+
if n == length(A)
430+
sitetensors[n] = reshape(C, size(C)[1:3]..., 1)
431+
break
432+
end
433+
434+
# Cmat: (link_ab * s_n * s_n'', link_anp1 * link_bnp1)
435+
436+
#lu = rrlu(Cmat; reltol, abstol, leftorthogonal=true)
437+
left, right, newbonddim = _factorize(
438+
reshape(C, prod(size(C)[1:3]), prod(size(C)[4:5])),
439+
method; tolerance, maxbonddim
440+
)
441+
442+
# U: (link_ab, s_n, s_n'', link_ab_new)
443+
sitetensors[n] = reshape(left, size(C)[1:3]..., newbonddim)
444+
445+
# R: (link_ab_new, link_an, link_bn)
446+
R = reshape(right, newbonddim, size(C)[4:5]...)
447+
end
448+
449+
return TensorTrain{ValueType,4}(sitetensors)
450+
end
451+
402452
"""
403453
function contract(
404454
A::TensorTrain{ValueType,4},
@@ -415,13 +465,15 @@ Contract two tensor trains `A` and `B`.
415465
Currently, two implementations are available:
416466
1. `algorithm=:TCI` constructs a new TCI that fits the contraction of `A` and `B`.
417467
2. `algorithm=:naive` uses a naive tensor contraction and subsequent SVD recompression of the tensor train.
468+
2. `algorithm=:zipup` uses a naive tensor contraction with on-the-fly LU decomposition.
418469
419470
Arguments:
420471
- `A` and `B` are the tensor trains to be contracted.
421472
- `algorithm` chooses the algorithm used to evaluate the contraction.
422473
- `tolerance` is the tolerance of the TCI or SVD recompression.
423474
- `maxbonddim` sets the maximum bond dimension of the resulting tensor train.
424475
- `f` is a function to be applied elementwise to the result. This option is only available with `algorithm=:TCI`.
476+
- `method` chooses the method used for the factorization in the `algorithm=:zipup` case (`:SVD` or `:LU`).
425477
- `kwargs...` are forwarded to [`crossinterpolate2`](@ref) if `algorithm=:TCI`.
426478
"""
427479
function contract(
@@ -440,6 +492,11 @@ function contract(
440492
error("Naive contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
441493
end
442494
return contract_naive(A, B; tolerance=tolerance, maxbonddim=maxbonddim)
495+
elseif algorithm === :zipup
496+
if f !== nothing
497+
error("Zipup contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
498+
end
499+
return contract_zipup(A, B; tolerance, maxbonddim)
443500
else
444501
throw(ArgumentError("Unknown algorithm $algorithm."))
445502
end

test/test_contraction.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function _tomat(tto::TensorTrain{T,4}) where {T}
1515
return mat
1616
end
1717

18-
function _tovec(tt::TensorTrain{T, 3}) where {T}
18+
function _tovec(tt::TensorTrain{T,3}) where {T}
1919
sitedims = TCI.sitedims(tt)
2020
localdims1 = [s[1] for s in sitedims]
2121
return evaluate.(Ref(tt), CartesianIndices(Tuple(localdims1))[:])
@@ -28,7 +28,45 @@ end
2828
@test vec(reshape(permutedims(a, (2, 1, 3)), 3, :) * reshape(permutedims(b, (1, 3, 2)), :, 5)) vec(ab)
2929
end
3030

31+
function _gen_testdata_TTO_TTO()
32+
N = 4
33+
bonddims_a = [1, 2, 3, 2, 1]
34+
bonddims_b = [1, 2, 3, 2, 1]
35+
localdims1 = [2, 2, 2, 2]
36+
localdims2 = [3, 3, 3, 3]
37+
localdims3 = [2, 2, 2, 2]
38+
39+
a = TensorTrain{ComplexF64,4}([
40+
rand(ComplexF64, bonddims_a[n], localdims1[n], localdims2[n], bonddims_a[n+1])
41+
for n = 1:N
42+
])
43+
b = TensorTrain{ComplexF64,4}([
44+
rand(ComplexF64, bonddims_b[n], localdims2[n], localdims3[n], bonddims_b[n+1])
45+
for n = 1:N
46+
])
47+
return N, a, b, localdims1, localdims2, localdims3
48+
end
49+
50+
function _gen_testdata_TTO_TTS()
51+
N = 4
52+
bonddims_a = [1, 2, 3, 2, 1]
53+
bonddims_b = [1, 2, 3, 2, 1]
54+
localdims1 = [3, 3, 3, 3]
55+
localdims2 = [3, 3, 3, 3]
56+
57+
a = TensorTrain{ComplexF64,4}([
58+
rand(ComplexF64, bonddims_a[n], localdims1[n], localdims2[n], bonddims_a[n+1])
59+
for n = 1:N
60+
])
61+
b = TensorTrain{ComplexF64,3}([
62+
rand(ComplexF64, bonddims_b[n], localdims2[n], bonddims_b[n+1])
63+
for n = 1:N
64+
])
65+
return N, a, b, localdims1, localdims2
66+
end
67+
3168
@testset "MPO-MPO contraction" for f in [nothing, x -> 2 * x], algorithm in [:TCI, :naive]
69+
#==
3270
N = 4
3371
bonddims_a = [1, 2, 3, 2, 1]
3472
bonddims_b = [1, 2, 3, 2, 1]
@@ -44,6 +82,8 @@ end
4482
rand(ComplexF64, bonddims_b[n], localdims2[n], localdims3[n], bonddims_b[n+1])
4583
for n = 1:N
4684
])
85+
==#
86+
N, a, b, localdims1, localdims2, localdims3 = _gen_testdata_TTO_TTO()
4787

4888
if f !== nothing && algorithm === :naive
4989
@test_throws ErrorException contract(a, b; f=f, algorithm=algorithm)
@@ -59,6 +99,7 @@ end
5999
end
60100

61101
@testset "MPO-MPS contraction" for f in [nothing, x -> 2 * x], algorithm in [:TCI, :naive]
102+
#==
62103
N = 4
63104
bonddims_a = [1, 2, 3, 2, 1]
64105
bonddims_b = [1, 2, 3, 2, 1]
@@ -73,6 +114,8 @@ end
73114
rand(ComplexF64, bonddims_b[n], localdims2[n], bonddims_b[n+1])
74115
for n = 1:N
75116
])
117+
==#
118+
N, a, b, localdims1, localdims2 = _gen_testdata_TTO_TTS()
76119

77120
if f !== nothing && algorithm === :naive
78121
@test_throws ErrorException contract(a, b; f=f, algorithm=algorithm)
@@ -90,3 +133,16 @@ end
90133
end
91134
end
92135
end
136+
137+
138+
@testset "MPO-MPO contraction (zipup)" for method in [:SVD, :LU]
139+
N, a, b, localdims1, localdims2, localdims3 = _gen_testdata_TTO_TTO()
140+
ab = contract(a, b; algorithm=:zipup, method=method)
141+
@test _tomat(ab) _tomat(a) * _tomat(b)
142+
end
143+
144+
@testset "MPO-MPS contraction (zipup)" for method in [:SVD, :LU]
145+
N, a, b, localdims1, localdims2 = _gen_testdata_TTO_TTS()
146+
ab = contract(a, b; algorithm=:zipup, method=method)
147+
@test _tovec(ab) _tomat(a) * _tovec(b)
148+
end

0 commit comments

Comments
 (0)