Skip to content

Commit 1a37ad4

Browse files
authored
Merge pull request #54 from tensor4all/reuse-local-pivots-in-tensorci2-constructor
add new constructor for TensorCI2 to initialize with local pivots list
2 parents 3a42110 + cc85787 commit 1a37ad4

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/tensorci2.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ function TensorCI2{ValueType}(
5353
return tci
5454
end
5555

56+
"""
57+
Initialize a TCI2 object with local pivot lists.
58+
"""
59+
function TensorCI2{ValueType}(
60+
func::F,
61+
localdims::Union{Vector{Int},NTuple{N,Int}},
62+
Iset::Vector{Vector{MultiIndex}},
63+
Jset::Vector{Vector{MultiIndex}}
64+
) where {F,ValueType,N}
65+
tci = TensorCI2{ValueType}(localdims)
66+
tci.Iset = Iset
67+
tci.Jset = Jset
68+
pivots = reconstractglobalpivotsfromijset(localdims, tci.Iset, tci.Jset)
69+
tci.maxsamplevalue = maximum(abs, (func(bit) for bit in pivots))
70+
abs(tci.maxsamplevalue) > 0.0 || error("maxsamplevalue is zero!")
71+
invalidatesitetensors!(tci)
72+
return tci
73+
end
74+
5675
@doc raw"""
5776
function printnestinginfo(tci::TensorCI2{T}) where {T}
5877
@@ -150,6 +169,24 @@ function updateerrors!(
150169
nothing
151170
end
152171

172+
function reconstractglobalpivotsfromijset(
173+
localdims::Union{Vector{Int},NTuple{N,Int}},
174+
Isets::Vector{Vector{MultiIndex}},
175+
Jsets::Vector{Vector{MultiIndex}}
176+
) where {N}
177+
pivots = []
178+
l = length(Isets)
179+
for i in 1:l
180+
for Iset in Isets[i]
181+
for Jset in Jsets[i]
182+
for j in 1:localdims[i]
183+
pushunique!(pivots, vcat(Iset, [j], Jset))
184+
end
185+
end
186+
end
187+
end
188+
return pivots
189+
end
153190

154191
"""
155192
Add global pivots to index sets

test/test_tensorci2.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,22 @@ import QuanticsGrids as QD
393393
end
394394

395395

396+
@testset "initialize_with_local_pivots_list" begin
397+
Random.seed!(1234)
398+
399+
N = 10
400+
M = rand(Float64, N, N)
401+
f(v) = M[v[1], v[2]] # 2D function
402+
localdims = fill(N, 2)
403+
mbd = 5
404+
405+
tci, ranks, errors = TCI.crossinterpolate2(Float64, f, localdims; maxbonddim=mbd)
406+
tci2 = TCI.TensorCI2{Float64}(f, localdims, tci.Iset, tci.Jset)
407+
@test tci2.maxsamplevalue == tci.maxsamplevalue
408+
@test tci2.Iset == tci.Iset
409+
@test tci2.Jset == tci.Jset
410+
end
411+
396412
@testset "crossinterpolate2_ttcache" begin
397413
ValueType = Float64
398414

0 commit comments

Comments
 (0)