Skip to content

Commit 4f35183

Browse files
authored
Merge branch 'main' into 17-improvements-for-compress-function-using-svd
2 parents b9d5d62 + 97d1248 commit 4f35183

File tree

9 files changed

+88
-23
lines changed

9 files changed

+88
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorCrossInterpolation"
22
uuid = "b261b2ec-6378-4871-b32e-9173bb050604"
33
authors = ["Ritter.Marc <[email protected]>, Hiroshi Shinaoka <[email protected]> and contributors"]
4-
version = "0.9.7"
4+
version = "0.9.10"
55

66
[deps]
77
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"

src/batcheval.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ struct ThreadedBatchEvaluator{T} <: BatchEvaluator{T}
9595
end
9696
end
9797

98+
function (obj::ThreadedBatchEvaluator{T})(indexset::Vector{Int})::T where {T}
99+
return obj.f(indexset)
100+
end
98101

99102
# Batch evaluation (loop over all index sets)
100103
function (obj::ThreadedBatchEvaluator{T})(leftindexset::Vector{Vector{Int}}, rightindexset::Vector{Vector{Int}}, ::Val{M})::Array{T,M + 2} where {T,M}

src/cachedfunction.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ function _batcheval_imp_default(f::CachedFunction{V,K},
9090
rightindexset::AbstractVector{MultiIndex},
9191
::Val{M}
9292
)::Array{V,M + 2} where {V,K,M}
93+
if length(leftindexset) * length(rightindexset) == 0
94+
return Array{V,M + 2}(undef, ntuple(d -> 0, M + 2)...)
95+
end
9396
nl = length(first(leftindexset))
9497
nr = length(first(rightindexset))
9598
L = length(f.localdims)
@@ -113,6 +116,9 @@ function _batcheval_imp_for_batchevaluator(f::CachedFunction{V,K},
113116
rightindexset::AbstractVector{MultiIndex},
114117
::Val{M}
115118
)::Array{V,M + 2} where {V,K,M}
119+
if length(leftindexset) * length(rightindexset) == 0
120+
return Array{V,M + 2}(undef, ntuple(d -> 0, M + 2)...)
121+
end
116122
nl = length(first(leftindexset))
117123
nr = length(first(rightindexset))
118124
L = length(f.localdims)

src/contraction.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,14 @@ end
487487

488488
"""
489489
function contract(
490-
A::TensorTrain{ValueType,4},
491-
B::TensorTrain{ValueType,4};
490+
A::TensorTrain{V1,4},
491+
B::TensorTrain{V2,4};
492492
algorithm::Symbol=:TCI,
493493
tolerance::Float64=1e-12,
494494
maxbonddim::Int=typemax(Int),
495495
f::Union{Nothing,Function}=nothing,
496496
kwargs...
497-
) where {ValueType}
497+
) where {V1,V2}
498498
499499
Contract two tensor trains `A` and `B`.
500500
@@ -513,45 +513,48 @@ Arguments:
513513
- `kwargs...` are forwarded to [`crossinterpolate2`](@ref) if `algorithm=:TCI`.
514514
"""
515515
function contract(
516-
A::TensorTrain{ValueType,4},
517-
B::TensorTrain{ValueType,4};
516+
A::TensorTrain{V1,4},
517+
B::TensorTrain{V2,4};
518518
algorithm::Symbol=:TCI,
519519
tolerance::Float64=1e-12,
520520
maxbonddim::Int=typemax(Int),
521521
f::Union{Nothing,Function}=nothing,
522522
kwargs...
523-
) where {ValueType}
523+
)::TensorTrain{promote_type(V1,V2),4} where {V1,V2}
524+
Vres = promote_type(V1, V2)
525+
A_ = TensorTrain{Vres,4}(A)
526+
B_ = TensorTrain{Vres,4}(B)
524527
if algorithm === :TCI
525-
return contract_TCI(A, B; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
528+
return contract_TCI(A_, B_; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
526529
elseif algorithm === :naive
527530
if f !== nothing
528531
error("Naive contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
529532
end
530-
return contract_naive(A, B; tolerance=tolerance, maxbonddim=maxbonddim)
533+
return contract_naive(A_, B_; tolerance=tolerance, maxbonddim=maxbonddim)
531534
elseif algorithm === :zipup
532535
if f !== nothing
533536
error("Zipup contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
534537
end
535-
return contract_zipup(A, B; tolerance, maxbonddim)
538+
return contract_zipup(A_, B_; tolerance, maxbonddim)
536539
else
537540
throw(ArgumentError("Unknown algorithm $algorithm."))
538541
end
539542
end
540543

541544
function contract(
542545
A::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}},
543-
B::TensorTrain{V,4};
546+
B::TensorTrain{V2,4};
544547
kwargs...
545-
)::TensorTrain{V,3} where {V}
548+
)::TensorTrain{promote_type(V,V2),3} where {V,V2}
546549
tt = contract(TensorTrain{4}(A, [(1, s...) for s in sitedims(A)]), B; kwargs...)
547550
return TensorTrain{3}(tt, prod.(sitedims(tt)))
548551
end
549552

550553
function contract(
551554
A::TensorTrain{V,4},
552-
B::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}};
555+
B::Union{TensorCI1{V2},TensorCI2{V2},TensorTrain{V2,3}};
553556
kwargs...
554-
)::TensorTrain{V,3} where {V}
557+
)::TensorTrain{promote_type(V,V2),3} where {V,V2}
555558
tt = contract(A, TensorTrain{4}(B, [(s..., 1) for s in sitedims(B)]); kwargs...)
556559
return TensorTrain{3}(tt, prod.(sitedims(tt)))
557560
end

src/tensorci2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function TensorCI2{ValueType}(
4747
) where {F,ValueType,N}
4848
tci = TensorCI2{ValueType}(localdims)
4949
addglobalpivots!(tci, initialpivots)
50-
tci.maxsamplevalue = maximum(abs, func.(initialpivots))
50+
tci.maxsamplevalue = maximum(abs, (func(x) for x in initialpivots))
5151
abs(tci.maxsamplevalue) > 0.0 || error("maxsamplevalue is zero!")
5252
invalidatesitetensors!(tci)
5353
return tci

src/tensortrain.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ mutable struct TensorTrain{ValueType,N} <: AbstractTensorTrain{ValueType}
3030
end
3131
end
3232

33+
function Base.show(io::IO, obj::TensorTrain{V,N}) where {V,N}
34+
print(
35+
io,
36+
"$(typeof(obj)) of rank $(maximum(linkdims(obj)))"
37+
)
38+
end
39+
40+
function TensorTrain{V2,N}(tt::TensorTrain{V})::TensorTrain{V2,N} where {V,V2,N}
41+
return TensorTrain{V2,N}(Array{V2}.(sitetensors(tt)))
42+
end
43+
3344
"""
3445
function TensorTrain(sitetensors::Vector{Array{V, 3}}) where {V}
3546
@@ -53,7 +64,7 @@ function TensorTrain(tci::AbstractTensorTrain{V})::TensorTrain{V,3} where {V}
5364
end
5465

5566
"""
56-
function TensorTrain{N}(tci::AbstractTensorTrain{V}) where {V,N}
67+
function TensorTrain{V2,N}(tci::AbstractTensorTrain{V}) where {V,V2,N}
5768
5869
Convert a tensor-train-like object into a tensor train.
5970
@@ -62,15 +73,15 @@ Arguments:
6273
- `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
6374
of `localdims` should be an array-like object of `N-2` integers.
6475
"""
65-
function TensorTrain{V,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
76+
function TensorTrain{V2,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V2,N} where {V,V2,N}
6677
for d in localdims
6778
length(d) == N - 2 || error("Each element of localdims be a list of N-2 integers.")
6879
end
6980
for n in 1:length(tt)
7081
prod(size(tt[n])[2:end-1]) == prod(localdims[n]) || error("The local dimensions at n=$n must match the tensor sizes.")
7182
end
72-
return TensorTrain{V,N}(
73-
[reshape(t, size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
83+
return TensorTrain{V2,N}(
84+
[reshape(Array{V2}(t), size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
7485
end
7586

7687
function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
@@ -262,4 +273,14 @@ end
262273
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
263274
tensors = to_tensors(obj, x)
264275
return sum((abs2(_evaluate(tensors, indexset) - obj.values[i]) for (i, indexset) in enumerate(obj.indexsets)))
276+
end
277+
278+
279+
280+
function fulltensor(obj::TensorTrain{T,N})::Array{T} where {T,N}
281+
sitedims_ = sitedims(obj)
282+
localdims = collect(prod.(sitedims_))
283+
r = [obj(collect(Tuple(i))) for i in CartesianIndices(Tuple(localdims))]
284+
returnsize = collect(Iterators.flatten(sitedims_))
285+
return reshape(r, returnsize...)
265286
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ include("test_tensortrain.jl")
2020
include("test_conversion.jl")
2121
include("test_contraction.jl")
2222
include("test_integration.jl")
23-
include("test_globalsearch.jl")
23+
include("test_globalsearch.jl")

test/test_batcheval.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,21 @@ end
6565

6666
@test result ≈ ref
6767
end
68-
end
68+
69+
@testset "ThreadedBatchEvaluator (from Matsuura)" begin
70+
function f(x)
71+
sleep(1e-3)
72+
return sum(x)
73+
end
74+
75+
L = 20
76+
localdims = fill(2, L)
77+
parf = TCI.ThreadedBatchEvaluator{Float64}(f, localdims)
78+
79+
tci, ranks, errors = TCI.crossinterpolate2(Float64, parf, localdims)
80+
81+
tci_ref, ranks_ref, errors_ref = TCI.crossinterpolate2(Float64, f, localdims)
82+
83+
@test TCI.fulltensor(TCI.TensorTrain(tci)) ≈ TCI.fulltensor(TCI.TensorTrain(tci_ref))
84+
end
85+
end

test/test_tensortrain.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ end
178178
@test ttmultileg2.(indicesmultileg) 2 .* ttmultileg.(indicesmultileg)
179179
end
180180

181-
182181
@testset "norm" begin
183182
T = Float64
184183
sitedims_ = [[2], [2], [2]]
@@ -217,4 +216,20 @@ end
217216
tt_compressed = deepcopy(tt)
218217
TCI.compress!(tt_compressed, :SVD; tolerance=LA.norm(tt) * tol, normalizeerror=false)
219218
@test sqrt(LA.norm2(tt - tt_compressed) / LA.norm2(tt)) < sqrt(N) * tol
220-
end
219+
end
220+
221+
@testset "tensor train cast" begin
222+
Random.seed!(10)
223+
localdims = [2, 2, 2]
224+
linkdims_ = [1, 2, 3, 1]
225+
L = length(localdims)
226+
227+
tt1 = TCI.TensorTrain{Float64,3}([randn(Float64, linkdims_[n], localdims[n], linkdims_[n+1]) for n in 1:L])
228+
229+
tt2 = TCI.TensorTrain{ComplexF64,3}(tt1, localdims)
230+
@test TCI.fulltensor(tt1) ≈ TCI.fulltensor(tt2)
231+
232+
tt3 = TCI.TensorTrain{Float64,3}(tt2, localdims)
233+
@test TCI.fulltensor(tt1) ≈ TCI.fulltensor(tt3)
234+
end
235+

0 commit comments

Comments
 (0)