Skip to content

Commit 37dbd7c

Browse files
committed
Refactored
1 parent a6b212b commit 37dbd7c

File tree

2 files changed

+11
-27
lines changed

2 files changed

+11
-27
lines changed

src/contraction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ _localdims(obj::TensorTrain{<:Any,4}, n::Int)::Tuple{Int,Int} =
6262
_localdims(obj::Contraction{<:Any}, n::Int)::Tuple{Int,Int} =
6363
(size(obj.mpo[1][n], 2), size(obj.mpo[2][n], 3))
6464

65-
_getindex(x, indices) = ntuple(i->x[indices[i]], length(indices))
65+
_getindex(x, indices) = ntuple(i -> x[indices[i]], length(indices))
6666

6767
function _contract(
6868
a::AbstractArray{T1,N1},

test/test_contraction.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
@test vec(reshape(permutedims(a, (2, 1, 3)), 3, :) * reshape(permutedims(b, (1, 3, 2)), :, 5)) vec(ab)
2323
end
2424

25-
@testset "MPO-MPO naive contraction" begin
25+
@testset "MPO-MPO contraction" for f in [nothing, x -> 2 * x], algorithm in ["TCI", "naive"]
2626
N = 4
2727
bonddims_a = [1, 2, 3, 2, 1]
2828
bonddims_b = [1, 2, 3, 2, 1]
@@ -39,29 +39,13 @@ end
3939
for n = 1:N
4040
])
4141

42-
ab = contract(a, b; algorithm="naive")
43-
44-
@test _tomat(ab) _tomat(a) * _tomat(b)
45-
end
46-
47-
@testset "MPO-MPO contraction" for f in [x -> x, x -> 2 * x]
48-
N = 4
49-
bonddims_a = [1, 2, 3, 2, 1]
50-
bonddims_b = [1, 2, 3, 2, 1]
51-
localdims1 = [2, 2, 2, 2]
52-
localdims2 = [3, 3, 3, 3]
53-
localdims3 = [2, 2, 2, 2]
54-
55-
a = TensorTrain{ComplexF64,4}([
56-
rand(ComplexF64, bonddims_a[n], localdims1[n], localdims2[n], bonddims_a[n+1])
57-
for n = 1:N
58-
])
59-
b = TensorTrain{ComplexF64,4}([
60-
rand(ComplexF64, bonddims_b[n], localdims2[n], localdims3[n], bonddims_b[n+1])
61-
for n = 1:N
62-
])
63-
64-
ab = contract(a, b; f = f, algorithm="TCI")
65-
@test sitedims(ab) == [[localdims1[i], localdims3[i]] for i = 1:N]
66-
@test _tomat(ab) f.(_tomat(a) * _tomat(b))
42+
if f === nothing || algorithm != "naive"
43+
ab = contract(a, b; f=f, algorithm=algorithm)
44+
@test sitedims(ab) == [[localdims1[i], localdims3[i]] for i = 1:N]
45+
if f === nothing
46+
@test _tomat(ab) _tomat(a) * _tomat(b)
47+
else
48+
@test _tomat(ab) f.(_tomat(a) * _tomat(b))
49+
end
50+
end
6751
end

0 commit comments

Comments
 (0)