Skip to content

Commit 2439e72

Browse files
committed
fixed tests
1 parent d7f113b commit 2439e72

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/contraction.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ function contract_naive(
324324
tolerance=0.0, maxbonddim=typemax(Int)
325325
)::TensorTrain{T,4} where {T}
326326
if obj.f isa Function
327-
error("Cannot contract matrix product with a function.")
327+
error("Naive contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
328328
end
329329

330330
a, b = obj.mpo
@@ -436,6 +436,9 @@ function contract(
436436
if algorithm === :TCI
437437
return contract_TCI(A, B; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
438438
elseif algorithm === :naive
439+
if f !== nothing
440+
error("Naive contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
441+
end
439442
return contract_naive(A, B; tolerance=tolerance, maxbonddim=maxbonddim)
440443
else
441444
throw(ArgumentError("Unknown algorithm $algorithm."))

test/test_contraction.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
@test vec(reshape(permutedims(a, (2, 1, 3)), 3, :) * reshape(permutedims(b, (1, 3, 2)), :, 5)) vec(ab)
2323
end
2424

25-
@testset "MPO-MPO contraction" for f in [nothing, x -> 2 * x], algorithm in ["TCI", "naive"]
25+
@testset "MPO-MPO contraction" for f in [nothing, x -> 2 * x], algorithm in [:TCI, :naive]
2626
N = 4
2727
bonddims_a = [1, 2, 3, 2, 1]
2828
bonddims_b = [1, 2, 3, 2, 1]
@@ -39,7 +39,9 @@ end
3939
for n = 1:N
4040
])
4141

42-
if f === nothing || algorithm != "naive"
42+
if f !== nothing && algorithm === :naive
43+
@test_throws ErrorException contract(a, b; f=f, algorithm=algorithm)
44+
else
4345
ab = contract(a, b; f=f, algorithm=algorithm)
4446
@test sitedims(ab) == [[localdims1[i], localdims3[i]] for i = 1:N]
4547
if f === nothing

0 commit comments

Comments
 (0)