Skip to content

Commit 60b6bee

Browse files
committed
Merged
2 parents b9d5d62 + 158b599 commit 60b6bee

12 files changed

+120
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorCrossInterpolation"
22
uuid = "b261b2ec-6378-4871-b32e-9173bb050604"
33
authors = ["Ritter.Marc <[email protected]>, Hiroshi Shinaoka <[email protected]> and contributors"]
4-
version = "0.9.7"
4+
version = "0.9.12"
55

66
[deps]
77
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"

src/batcheval.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ struct ThreadedBatchEvaluator{T} <: BatchEvaluator{T}
9595
end
9696
end
9797

98+
function (obj::ThreadedBatchEvaluator{T})(indexset::Vector{Int})::T where {T}
99+
return obj.f(indexset)
100+
end
98101

99102
# Batch evaluation (loop over all index sets)
100103
function (obj::ThreadedBatchEvaluator{T})(leftindexset::Vector{Vector{Int}}, rightindexset::Vector{Vector{Int}}, ::Val{M})::Array{T,M + 2} where {T,M}

src/cachedfunction.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ function _batcheval_imp_default(f::CachedFunction{V,K},
9090
rightindexset::AbstractVector{MultiIndex},
9191
::Val{M}
9292
)::Array{V,M + 2} where {V,K,M}
93+
if length(leftindexset) * length(rightindexset) == 0
94+
return Array{V,M + 2}(undef, ntuple(d -> 0, M + 2)...)
95+
end
9396
nl = length(first(leftindexset))
9497
nr = length(first(rightindexset))
9598
L = length(f.localdims)
@@ -113,6 +116,9 @@ function _batcheval_imp_for_batchevaluator(f::CachedFunction{V,K},
113116
rightindexset::AbstractVector{MultiIndex},
114117
::Val{M}
115118
)::Array{V,M + 2} where {V,K,M}
119+
if length(leftindexset) * length(rightindexset) == 0
120+
return Array{V,M + 2}(undef, ntuple(d -> 0, M + 2)...)
121+
end
116122
nl = length(first(leftindexset))
117123
nr = length(first(rightindexset))
118124
L = length(f.localdims)

src/contraction.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,14 @@ end
487487

488488
"""
489489
function contract(
490-
A::TensorTrain{ValueType,4},
491-
B::TensorTrain{ValueType,4};
490+
A::TensorTrain{V1,4},
491+
B::TensorTrain{V2,4};
492492
algorithm::Symbol=:TCI,
493493
tolerance::Float64=1e-12,
494494
maxbonddim::Int=typemax(Int),
495495
f::Union{Nothing,Function}=nothing,
496496
kwargs...
497-
) where {ValueType}
497+
) where {V1,V2}
498498
499499
Contract two tensor trains `A` and `B`.
500500
@@ -513,45 +513,48 @@ Arguments:
513513
- `kwargs...` are forwarded to [`crossinterpolate2`](@ref) if `algorithm=:TCI`.
514514
"""
515515
function contract(
516-
A::TensorTrain{ValueType,4},
517-
B::TensorTrain{ValueType,4};
516+
A::TensorTrain{V1,4},
517+
B::TensorTrain{V2,4};
518518
algorithm::Symbol=:TCI,
519519
tolerance::Float64=1e-12,
520520
maxbonddim::Int=typemax(Int),
521521
f::Union{Nothing,Function}=nothing,
522522
kwargs...
523-
) where {ValueType}
523+
)::TensorTrain{promote_type(V1,V2),4} where {V1,V2}
524+
Vres = promote_type(V1, V2)
525+
A_ = TensorTrain{Vres,4}(A)
526+
B_ = TensorTrain{Vres,4}(B)
524527
if algorithm === :TCI
525-
return contract_TCI(A, B; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
528+
return contract_TCI(A_, B_; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
526529
elseif algorithm === :naive
527530
if f !== nothing
528531
error("Naive contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
529532
end
530-
return contract_naive(A, B; tolerance=tolerance, maxbonddim=maxbonddim)
533+
return contract_naive(A_, B_; tolerance=tolerance, maxbonddim=maxbonddim)
531534
elseif algorithm === :zipup
532535
if f !== nothing
533536
error("Zipup contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
534537
end
535-
return contract_zipup(A, B; tolerance, maxbonddim)
538+
return contract_zipup(A_, B_; tolerance, maxbonddim)
536539
else
537540
throw(ArgumentError("Unknown algorithm $algorithm."))
538541
end
539542
end
540543

541544
function contract(
542545
A::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}},
543-
B::TensorTrain{V,4};
546+
B::TensorTrain{V2,4};
544547
kwargs...
545-
)::TensorTrain{V,3} where {V}
548+
)::TensorTrain{promote_type(V,V2),3} where {V,V2}
546549
tt = contract(TensorTrain{4}(A, [(1, s...) for s in sitedims(A)]), B; kwargs...)
547550
return TensorTrain{3}(tt, prod.(sitedims(tt)))
548551
end
549552

550553
function contract(
551554
A::TensorTrain{V,4},
552-
B::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}};
555+
B::Union{TensorCI1{V2},TensorCI2{V2},TensorTrain{V2,3}};
553556
kwargs...
554-
)::TensorTrain{V,3} where {V}
557+
)::TensorTrain{promote_type(V,V2),3} where {V,V2}
555558
tt = contract(A, TensorTrain{4}(B, [(s..., 1) for s in sitedims(B)]); kwargs...)
556559
return TensorTrain{3}(tt, prod.(sitedims(tt)))
557560
end

src/globalsearch.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ function _floatingzone(
8080
)
8181
err = vec(abs.(exactdata .- prediction))
8282
pivot[ipos] = argmax(err)
83-
maxerror = maximum(err)
83+
# In RHS, we compare the maximum of the error vector with the current maxerror
84+
# to make sure that the error does not decrease even if maxerror is close to machine precision.
85+
maxerror = max(maximum(err), maxerror)
8486
end
8587

8688
if maxerror == prev_maxerror || maxerror > earlystoptol # early stop

src/tensorci2.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function TensorCI2{ValueType}(
4747
) where {F,ValueType,N}
4848
tci = TensorCI2{ValueType}(localdims)
4949
addglobalpivots!(tci, initialpivots)
50-
tci.maxsamplevalue = maximum(abs, func.(initialpivots))
50+
tci.maxsamplevalue = maximum(abs, (func(x) for x in initialpivots))
5151
abs(tci.maxsamplevalue) > 0.0 || error("maxsamplevalue is zero!")
5252
invalidatesitetensors!(tci)
5353
return tci
@@ -431,7 +431,8 @@ function makecanonical!(
431431
abstol::Float64=0.0,
432432
maxbonddim::Int=typemax(Int)
433433
) where {F,ValueType}
434-
sweep1site!(tci, f, :forward; reltol, abstol, maxbonddim, updatetensors=false)
434+
# The first half-sweep is performed exactly without compression.
435+
sweep1site!(tci, f, :forward; reltol=0.0, abstol=0.0, maxbonddim=typemax(Int), updatetensors=false)
435436
sweep1site!(tci, f, :backward; reltol, abstol, maxbonddim, updatetensors=false)
436437
sweep1site!(tci, f, :forward; reltol, abstol, maxbonddim, updatetensors=true)
437438
end

src/tensortrain.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ mutable struct TensorTrain{ValueType,N} <: AbstractTensorTrain{ValueType}
3030
end
3131
end
3232

33+
function Base.show(io::IO, obj::TensorTrain{V,N}) where {V,N}
34+
print(
35+
io,
36+
"$(typeof(obj)) of rank $(maximum(linkdims(obj)))"
37+
)
38+
end
39+
40+
function TensorTrain{V2,N}(tt::TensorTrain{V})::TensorTrain{V2,N} where {V,V2,N}
41+
return TensorTrain{V2,N}(Array{V2}.(sitetensors(tt)))
42+
end
43+
3344
"""
3445
function TensorTrain(sitetensors::Vector{Array{V, 3}}) where {V}
3546
@@ -53,7 +64,7 @@ function TensorTrain(tci::AbstractTensorTrain{V})::TensorTrain{V,3} where {V}
5364
end
5465

5566
"""
56-
function TensorTrain{N}(tci::AbstractTensorTrain{V}) where {V,N}
67+
function TensorTrain{V2,N}(tci::AbstractTensorTrain{V}) where {V,V2,N}
5768
5869
Convert a tensor-train-like object into a tensor train.
5970
@@ -62,15 +73,15 @@ Arguments:
6273
- `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
6374
of `localdims` should be an array-like object of `N-2` integers.
6475
"""
65-
function TensorTrain{V,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
76+
function TensorTrain{V2,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V2,N} where {V,V2,N}
6677
for d in localdims
6778
length(d) == N - 2 || error("Each element of localdims be a list of N-2 integers.")
6879
end
6980
for n in 1:length(tt)
7081
prod(size(tt[n])[2:end-1]) == prod(localdims[n]) || error("The local dimensions at n=$n must match the tensor sizes.")
7182
end
72-
return TensorTrain{V,N}(
73-
[reshape(t, size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
83+
return TensorTrain{V2,N}(
84+
[reshape(Array{V2}(t), size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
7485
end
7586

7687
function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
@@ -262,4 +273,20 @@ end
262273
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
263274
tensors = to_tensors(obj, x)
264275
return sum((abs2(_evaluate(tensors, indexset) - obj.values[i]) for (i, indexset) in enumerate(obj.indexsets)))
276+
end
277+
278+
279+
function fulltensor(obj::TensorTrain{T,N})::Array{T} where {T,N}
280+
sitedims_ = sitedims(obj)
281+
localdims = collect(prod.(sitedims_))
282+
result::Matrix{T} = reshape(obj.sitetensors[1], localdims[1], :)
283+
leftdim = localdims[1]
284+
for l in 2:length(obj)
285+
nextmatrix = reshape(
286+
obj.sitetensors[l], size(obj.sitetensors[l], 1), localdims[l] * size(obj.sitetensors[l])[end])
287+
leftdim *= localdims[l]
288+
result = reshape(result * nextmatrix, leftdim, size(obj.sitetensors[l])[end])
289+
end
290+
returnsize = collect(Iterators.flatten(sitedims_))
291+
return reshape(result, returnsize...)
265292
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ include("test_tensortrain.jl")
2020
include("test_conversion.jl")
2121
include("test_contraction.jl")
2222
include("test_integration.jl")
23-
include("test_globalsearch.jl")
23+
include("test_globalsearch.jl")

test/test_batcheval.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,21 @@ end
6565

6666
@test result ref
6767
end
68-
end
68+
69+
@testset "ThreadedBatchEvaluator (from Matsuura)" begin
70+
function f(x)
71+
sleep(1e-3)
72+
return sum(x)
73+
end
74+
75+
L = 20
76+
localdims = fill(2, L)
77+
parf = TCI.ThreadedBatchEvaluator{Float64}(f, localdims)
78+
79+
tci, ranks, errors = TCI.crossinterpolate2(Float64, parf, localdims)
80+
81+
tci_ref, ranks_ref, errors_ref = TCI.crossinterpolate2(Float64, f, localdims)
82+
83+
@test TCI.fulltensor(TCI.TensorTrain(tci)) TCI.fulltensor(TCI.TensorTrain(tci_ref))
84+
end
85+
end

test/test_globalsearch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ import QuanticsGrids as QD
3232
pivoterrors = TCI.estimatetrueerror(TCI.TensorTrain(tci), f)
3333

3434
errors = [e for (_, e) in pivoterrors]
35-
@test all([abs(f(p) - tci(p)) for (p, _) in pivoterrors] .== errors)
35+
@test [abs(f(p) - tci(p)) for (p, _) in pivoterrors] errors
3636
@test all(errors[1:end-1] .>= errors[2:end]) # check if errors are sorted in descending order
3737
end

0 commit comments

Comments
 (0)