Skip to content

Commit 8924fb5

Browse files
committed
Merge branch '48-document-global-pivot-search-and-activate-it-by-default' into 'main'
Resolve "Document global pivot search and activate it by default" See merge request tensors4fields/TensorCrossInterpolation.jl!92
2 parents 4e3c0d0 + ae8967a commit 8924fb5

File tree

4 files changed

+67
-88
lines changed

4 files changed

+67
-88
lines changed

docs/src/index.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,24 @@ tci, ranks, errors = TCI.crossinterpolate2(
117117
```
118118
This algorithm optimizes the given index set (in this case `[1, 2, 3, 4, 5]`) by searching for a maximum absolute value, alternating through the dimensions. If no starting point is given, `[1, 1, ...]` is used.
119119

120+
## Combing TCI2 and global pivot search
121+
The main algorithm for adding new pivots in TCI2 is the 2-site algorithm, which is local.
122+
The 2-site algorithm alone may miss some regions with high interpolation error.
123+
124+
The current TCI2 implementation provides the combination of the 2-site algorithm and a global search algorithm to find such regions.
125+
This functionality is activated by default.
126+
In the function [`crossinterpolate2`](@ref), we alternate between a 2-site-update sweep and a global pivot insertion.
127+
After a 2-site-update sweep, we search for index sets with high interpolation errors (> the given tolerance multiplied by the parameter `tolmarginglobalsearch`) and add them to the TCI2 object, and then we continue with a 2-site-update sweep.
128+
129+
The number of initial points used in one global search is controlled by the parameter `nsearchglobalpivot`.
130+
You may consider increasing this number if the global search is not effective (check the number of pivots found and timings of the global search by setting `verbosity` to a higher value!).
131+
The maximum number of global pivots inserted at once is controlled by the parameter `maxnglobalpivot`.
132+
133+
A rare failure case is that the global search find the index sets with high interpolation errors, but the 2-site algorithm fails to add these pivots into the TCI2 object.
134+
This will end up adding the same index sets in the next global search, leading to an endless loop.
135+
120136
## Estiamte true interpolation error by random global search
121-
Since the TCI update algorithms are local, the true interpolation error is not known. However, the error can be estimated by global searches. This is implemented in the function `estimatetrueerror`:
137+
Since most of the TCI update algorithms are local, the true interpolation error is not known. However, the error can be estimated by global searches. This is implemented in the function [`estimatetrueerror`](@ref):
122138

123139
```julia
124140
pivoterrors = TCI.estimatetrueerror(TCI.TensorTrain(tci), f)

src/globalsearch.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function estimatetrueerror(
3232

3333
ttcache = TTCache(tt)
3434

35-
pivoterror = [_floatingzone(ttcache, f, initp) for initp in initialpoints]
35+
pivoterror = [_floatingzone(ttcache, f; initp=initp) for initp in initialpoints]
3636

3737
p = sortperm([e for (_, e) in pivoterror], rev=true)
3838

@@ -41,14 +41,25 @@ end
4141

4242

4343
function _floatingzone(
44-
ttcache::TTCache{ValueType}, f, pivot
44+
ttcache::TTCache{ValueType}, f;
45+
earlystoptol::Float64 = typemax(Float64),
46+
nsweeps=typemax(Int), initp::Union{Nothing,MultiIndex} = nothing
4547
)::Tuple{MultiIndex,Float64} where {ValueType}
48+
nsweeps > 0 || error("nsweeps should be positive!")
49+
50+
localdims = first.(sitedims(ttcache))
51+
4652
n = length(ttcache)
4753

54+
if initp === nothing
55+
pivot = [rand(1:d) for d in localdims]
56+
else
57+
pivot = initp
58+
end
59+
4860
maxerror = abs(f(pivot) - ttcache(pivot))
49-
localdims = first.(sitedims(ttcache))
5061

51-
while true
62+
for isweep in 1:nsweeps
5263
prev_maxerror = maxerror
5364
for ipos in 1:n
5465
exactdata = filltensor(
@@ -72,10 +83,28 @@ function _floatingzone(
7283
maxerror = maximum(err)
7384
end
7485

75-
if maxerror == prev_maxerror
86+
if maxerror == prev_maxerror || maxerror > earlystoptol # early stop
7687
break
7788
end
7889
end
7990

8091
return pivot, maxerror
92+
end
93+
94+
95+
function fillsitetensors!(
96+
tci::TensorCI2{ValueType}, f) where {ValueType}
97+
for b in 1:length(tci)
98+
setsitetensor!(tci, f, b)
99+
end
100+
nothing
101+
end
102+
103+
104+
function _sanitycheck(tci::TensorCI2{ValueType})::Bool where {ValueType}
105+
for b in 1:length(tci)-1
106+
length(tci.Iset[b+1]) == length(tci.Jset[b]) || error("Pivot matrix at bond $(b) is not square!")
107+
end
108+
109+
return true
81110
end

src/integration.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ function integrate(
2323
a::Vector{ValueType},
2424
b::Vector{ValueType};
2525
GKorder::Int=15,
26-
normalizeerror=false,
2726
kwargs...
2827
) where {ValueType}
2928
if iseven(GKorder)
@@ -52,7 +51,6 @@ function integrate(
5251
F,
5352
localdims;
5453
nsearchglobalpivot=10,
55-
normalizeerror,
5654
kwargs...
5755
)
5856

src/tensorci2.jl

Lines changed: 16 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ function addglobalpivots2sitesweep!(
214214
pivotsearch::Symbol=:full,
215215
verbosity::Int=0,
216216
ntry::Int=10,
217-
strictlynested::Bool=true
217+
strictlynested::Bool=false
218218
)::Int where {F,ValueType}
219219
if any(length(tci) .!= length.(pivots))
220220
throw(DimensionMismatch("Please specify a pivot as one index per leg of the MPS."))
@@ -606,9 +606,9 @@ end
606606
normalizeerror::Bool=true,
607607
ncheckhistory=3,
608608
maxnglobalpivot::Int=5,
609-
nsearchglobalpivot::Int=0,
609+
nsearchglobalpivot::Int=5,
610610
tolmarginglobalsearch::Float64=10.0,
611-
strictlynested::Bool=true
611+
strictlynested::Bool=false
612612
) where {ValueType}
613613
614614
Perform optimization sweeps using the TCI2 algorithm. This will sucessively improve the TCI approximation of a function until it fits `f` with an error smaller than `tolerance`, or until the maximum bond dimension (`maxbonddim`) is reached.
@@ -629,7 +629,7 @@ Arguments:
629629
- `normalizeerror::Bool` determines whether to scale the error by the maximum absolute value of `f` found during sampling. If set to `false`, the algorithm continues until the *absolute* error is below `tolerance`. If set to `true`, the algorithm uses the absolute error divided by the maximum sample instead. This is helpful if the magnitude of the function is not known in advance. Default: `true`.
630630
- `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`.
631631
- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`.
632-
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `0`.
632+
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`.
633633
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`.
634634
- `strictlynested::Bool` determines whether to preserve partial nesting in the TCI algorithm. Default: `false`.
635635
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
@@ -655,7 +655,7 @@ function optimize!(
655655
normalizeerror::Bool=true,
656656
ncheckhistory::Int=3,
657657
maxnglobalpivot::Int=5,
658-
nsearchglobalpivot::Int=0,
658+
nsearchglobalpivot::Int=5,
659659
tolmarginglobalsearch::Float64=10.0,
660660
strictlynested::Bool=false,
661661
checkbatchevaluatable::Bool=false
@@ -704,10 +704,11 @@ function optimize!(
704704
sweepstrategy=sweepstrategy,
705705
fillsitetensors=true
706706
)
707-
if verbosity > 0 && length(globalpivots) > 0
708-
nrejections = length([p for p in globalpivots if abs(evaluate(tci, p) - f(p)) > abstol])
707+
if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0
708+
abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots]
709+
nrejections = length(abserr .> abstol)
709710
if nrejections > 0
710-
println(" Rejected $(nrejections) global pivots added in the previous iteration")
711+
println(" Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)")
711712
flush(stdout)
712713
end
713714
end
@@ -776,7 +777,7 @@ function sweep2site!(
776777
sweepstrategy::Symbol=:backandforth,
777778
pivotsearch::Symbol=:full,
778779
verbosity::Int=0,
779-
strictlynested::Bool=true,
780+
strictlynested::Bool=false,
780781
fillsitetensors::Bool=true
781782
) where {ValueType}
782783
invalidatesitetensors!(tci)
@@ -850,9 +851,9 @@ end
850851
normalizeerror::Bool=true,
851852
ncheckhistory=3,
852853
maxnglobalpivot::Int=5,
853-
nsearchglobalpivot::Int=0,
854+
nsearchglobalpivot::Int=5,
854855
tolmarginglobalsearch::Float64=10.0,
855-
strictlynested::Bool=true
856+
strictlynested::Bool=false
856857
) where {ValueType,N}
857858
858859
Cross interpolate a function ``f(\mathbf{u})`` using the TCI2 algorithm. Here, the domain of ``f`` is ``\mathbf{u} \in [1, \ldots, d_1] \times [1, \ldots, d_2] \times \ldots \times [1, \ldots, d_{\mathscr{L}}]`` and ``d_1 \ldots d_{\mathscr{L}}`` are the local dimensions.
@@ -873,9 +874,9 @@ Arguments:
873874
- `normalizeerror::Bool` determines whether to scale the error by the maximum absolute value of `f` found during sampling. If set to `false`, the algorithm continues until the *absolute* error is below `tolerance`. If set to `true`, the algorithm uses the absolute error divided by the maximum sample instead. This is helpful if the magnitude of the function is not known in advance. Default: `true`.
874875
- `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`.
875876
- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`.
876-
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `0`.
877+
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`.
877878
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`.
878-
- `strictlynested::Bool=true` determines whether to preserve partial nesting in the TCI algorithm. Default: `true`.
879+
- `strictlynested::Bool=false` determines whether to preserve partial nesting in the TCI algorithm. Default: `true`.
879880
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
880881
881882
Notes:
@@ -916,8 +917,9 @@ function searchglobalpivots(
916917
end
917918

918919
pivots = Dict{Float64,MultiIndex}()
920+
ttcache = TTCache(tci)
919921
for _ in 1:nsearch
920-
pivot, error = _floatingzone(tci, f, 10 * abstol)
922+
pivot, error = _floatingzone(ttcache, f; earlystoptol = 10 * abstol, nsweeps=100)
921923
if error > abstol
922924
pivots[error] = pivot
923925
end
@@ -941,69 +943,3 @@ function searchglobalpivots(
941943
return [p for (_,p) in pivots]
942944
end
943945

944-
945-
function _floatingzone(
946-
tci::TensorCI2{ValueType}, f, abstol;
947-
nsweeps=100
948-
)::Tuple{MultiIndex,Float64} where {ValueType}
949-
nsweeps > 0 || error("nsweeps should be positive!")
950-
951-
localdims = tci.localdims
952-
953-
n = length(tci)
954-
955-
ttcache = TTCache(tci)
956-
957-
pivot = [rand(1:d) for d in localdims]
958-
959-
maxerror = abs(f(pivot) - ttcache(pivot))
960-
961-
for isweep in 1:nsweeps
962-
prev_maxerror = maxerror
963-
for ipos in 1:n
964-
exactdata = filltensor(
965-
ValueType,
966-
f,
967-
tci.localdims,
968-
[pivot[1:ipos-1]],
969-
[pivot[ipos+1:end]],
970-
Val(1)
971-
)
972-
prediction = filltensor(
973-
ValueType,
974-
ttcache,
975-
tci.localdims,
976-
[pivot[1:ipos-1]],
977-
[pivot[ipos+1:end]],
978-
Val(1)
979-
)
980-
err = vec(abs.(exactdata .- prediction))
981-
pivot[ipos] = argmax(err)
982-
maxerror = maximum(err)
983-
end
984-
985-
if maxerror == prev_maxerror || maxerror > abstol # early stop
986-
break
987-
end
988-
end
989-
990-
return pivot, maxerror
991-
end
992-
993-
994-
function fillsitetensors!(
995-
tci::TensorCI2{ValueType}, f) where {ValueType}
996-
for b in 1:length(tci)
997-
setsitetensor!(tci, f, b)
998-
end
999-
nothing
1000-
end
1001-
1002-
1003-
function _sanitycheck(tci::TensorCI2{ValueType})::Bool where {ValueType}
1004-
for b in 1:length(tci)-1
1005-
length(tci.Iset[b+1]) == length(tci.Jset[b]) || error("Pivot matrix at bond $(b) is not square!")
1006-
end
1007-
1008-
return true
1009-
end

0 commit comments

Comments
 (0)