Skip to content

Commit e9a191c

Browse files
authored
Merge pull request #26 from tensor4all/25-bug-in-threadbatchevaluator
Fix bug in ThreadBatchEvaluator
2 parents 2627d0e + 82ed982 commit e9a191c

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/batcheval.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ struct ThreadedBatchEvaluator{T} <: BatchEvaluator{T}
9595
end
9696
end
9797

98+
function (obj::ThreadedBatchEvaluator{T})(indexset::Vector{Int})::T where {T}
99+
return obj.f(indexset)
100+
end
98101

99102
# Batch evaluation (loop over all index sets)
100103
function (obj::ThreadedBatchEvaluator{T})(leftindexset::Vector{Vector{Int}}, rightindexset::Vector{Vector{Int}}, ::Val{M})::Array{T,M + 2} where {T,M}

test/test_batcheval.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,21 @@ end
6565

6666
@test result ref
6767
end
68-
end
68+
69+
@testset "ThreadedBatchEvaluator (from Matsuura)" begin
70+
function f(x)
71+
sleep(1e-3)
72+
return sum(x)
73+
end
74+
75+
L = 20
76+
localdims = fill(2, L)
77+
parf = TCI.ThreadedBatchEvaluator{Float64}(f, localdims)
78+
79+
tci, ranks, errors = TCI.crossinterpolate2(Float64, parf, localdims)
80+
81+
tci_ref, ranks_ref, errors_ref = TCI.crossinterpolate2(Float64, f, localdims)
82+
83+
@test TCI.fulltensor(TCI.TensorTrain(tci)) TCI.fulltensor(TCI.TensorTrain(tci_ref))
84+
end
85+
end

0 commit comments

Comments
 (0)