Skip to content

Commit a92a528

Browse files
committed
added MPO-MPS contractions
1 parent 2439e72 commit a92a528

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

src/contraction.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,3 +444,28 @@ function contract(
444444
throw(ArgumentError("Unknown algorithm $algorithm."))
445445
end
446446
end
447+
448+
function _promoteMPStoMPO(
449+
tt::AbstractTensorTrain{V},
450+
unfusedlocalshape::Union{AbstractVector{Int}, Tuple}
451+
)::TensorTrain{V, 4} where {V}
452+
return TensorTrain{V, 4}(_reshape_splitsites.(sitetensors(tt), Ref(unfusedlocalshape)))
453+
end
454+
455+
function contract(
456+
A::Union{TensorCI1{V}, TensorCI2{V}, TensorTrain{V, 3}},
457+
B::TensorTrain{V, 4};
458+
kwargs...
459+
)::TensorTrain{V, 3} where {V}
460+
tt = contract(_promoteMPStoMPO(A, (1, sitedim(A, 1)...)), B; kwargs...)
461+
return TensorTrain{V, 3}([T for (T, shape) in _reshape_fusesites.(sitetensors(tt))])
462+
end
463+
464+
function contract(
465+
A::TensorTrain{V, 4},
466+
B::Union{TensorCI1{V}, TensorCI2{V}, TensorTrain{V, 3}};
467+
kwargs...
468+
)::TensorTrain{V, 3} where {V}
469+
tt = contract(A, _promoteMPStoMPO(B, (sitedim(B, 1)..., 1)); kwargs...)
470+
return TensorTrain{V, 3}([T for (T, shape) in _reshape_fusesites.(sitetensors(tt))])
471+
end

test/test_contraction.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ function _tomat(tto::TensorTrain{T,4}) where {T}
1515
return mat
1616
end
1717

18+
function _tovec(tt::TensorTrain{T, 3}) where {T}
19+
sitedims = TCI.sitedims(tt)
20+
localdims1 = [s[1] for s in sitedims]
21+
return evaluate.(Ref(tt), CartesianIndices(Tuple(localdims1))[:])
22+
end
23+
1824
@testset "_contract" begin
1925
a = rand(2, 3, 4)
2026
b = rand(2, 5, 4)
@@ -51,3 +57,36 @@ end
5157
end
5258
end
5359
end
60+
61+
@testset "MPO-MPS contraction" for f in [nothing, x -> 2 * x], algorithm in [:TCI, :naive]
62+
N = 4
63+
bonddims_a = [1, 2, 3, 2, 1]
64+
bonddims_b = [1, 2, 3, 2, 1]
65+
localdims1 = [3, 3, 3, 3]
66+
localdims2 = [3, 3, 3, 3]
67+
68+
a = TensorTrain{ComplexF64,4}([
69+
rand(ComplexF64, bonddims_a[n], localdims1[n], localdims2[n], bonddims_a[n+1])
70+
for n = 1:N
71+
])
72+
b = TensorTrain{ComplexF64,3}([
73+
rand(ComplexF64, bonddims_b[n], localdims2[n], bonddims_b[n+1])
74+
for n = 1:N
75+
])
76+
77+
if f !== nothing && algorithm === :naive
78+
@test_throws ErrorException contract(a, b; f=f, algorithm=algorithm)
79+
@test_throws ErrorException contract(b, a; f=f, algorithm=algorithm)
80+
else
81+
ab = contract(a, b; f=f, algorithm=algorithm)
82+
ba = contract(b, a; f=f, algorithm=algorithm)
83+
@test sitedims(ab) == [[localdims1[i]] for i = 1:N]
84+
if f === nothing
85+
@test _tovec(ab) _tomat(a) * _tovec(b)
86+
@test transpose(_tovec(ba)) transpose(_tovec(b)) * _tomat(a)
87+
else
88+
@test _tovec(ab) f.(_tomat(a) * _tovec(b))
89+
@test transpose(_tovec(ba)) f.(transpose(_tovec(b)) * _tomat(a))
90+
end
91+
end
92+
end

0 commit comments

Comments
 (0)