Skip to content

Commit 12715e2

Browse files
committed
Allow custom global pivot finders
1 parent 883a8fc commit 12715e2

File tree

3 files changed

+90
-43
lines changed

3 files changed

+90
-43
lines changed

src/TensorCrossInterpolation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("batcheval.jl")
3333
include("cachedfunction.jl")
3434
include("tensorci1.jl")
3535
include("tensorci2.jl")
36+
include("globalpivotfinder.jl")
3637
include("tensortrain.jl")
3738
include("conversion.jl")
3839
include("integration.jl")

src/tensorci2.jl

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Abstract type for global pivot finders in TCI2 algorithm.
3+
"""
4+
abstract type AbstractGlobalPivotFinder end
15

26
"""
37
mutable struct TensorCI2{ValueType} <: AbstractTensorTrain{ValueType}
@@ -664,9 +668,10 @@ Arguments:
664668
- `loginterval::Int` can be set to `>= 1` to specify how frequently to print convergence information. Default: `10`.
665669
- `normalizeerror::Bool` determines whether to scale the error by the maximum absolute value of `f` found during sampling. If set to `false`, the algorithm continues until the *absolute* error is below `tolerance`. If set to `true`, the algorithm uses the absolute error divided by the maximum sample instead. This is helpful if the magnitude of the function is not known in advance. Default: `true`.
666670
- `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`.
667-
- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`.
668-
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`.
669-
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`.
671+
- `globalpivotfinder::Union{AbstractGlobalPivotFinder, Nothing}` is a global pivot finder to use for searching global pivots. Default: `nothing`. If `nothing`, a default global pivot finder is used.
672+
- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`. The maximum number of global pivots to add in each iteration.
673+
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`. This parameter is used for the default global pivot finder. Deprecated.
674+
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`. This parameter is used for the default global pivot finder. Deprecated.
670675
- `strictlynested::Bool` determines whether to preserve partial nesting in the TCI algorithm. Default: `false`.
671676
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
672677
@@ -690,6 +695,7 @@ function optimize!(
690695
loginterval::Int=10,
691696
normalizeerror::Bool=true,
692697
ncheckhistory::Int=3,
698+
globalpivotfinder::Union{AbstractGlobalPivotFinder, Nothing}=nothing,
693699
maxnglobalpivot::Int=5,
694700
nsearchglobalpivot::Int=5,
695701
tolmarginglobalsearch::Float64=10.0,
@@ -705,9 +711,6 @@ function optimize!(
705711
error("Function `f` is not batch evaluatable")
706712
end
707713

708-
#if maxnglobalpivot > 0 && nsearchglobalpivot > 0
709-
#!strictlynested || error("nglobalpivots > 0 requires strictlynested=false!")
710-
#end
711714
if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot
712715
error("nsearchglobalpivot < maxnglobalpivot!")
713716
end
@@ -734,6 +737,17 @@ function optimize!(
734737
))
735738
end
736739

740+
# Create the global pivot finder
741+
finder = if isnothing(globalpivotfinder)
742+
DefaultGlobalPivotFinder(
743+
nsearch=nsearchglobalpivot,
744+
maxnglobalpivot=maxnglobalpivot,
745+
tolmarginglobalsearch=tolmarginglobalsearch
746+
)
747+
else
748+
globalpivotfinder
749+
end
750+
737751
globalpivots = MultiIndex[]
738752
for iter in 1:maxiter
739753
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
@@ -771,12 +785,9 @@ function optimize!(
771785
end
772786

773787
# Find global pivots where the error is too large
774-
# Such gloval pivots are added to the TCI, invalidating site tensors.
775-
globalpivots = searchglobalpivots(
776-
tci, f, tolmarginglobalsearch * abstol,
777-
verbosity=verbosity,
778-
maxnglobalpivot=maxnglobalpivot,
779-
nsearch=nsearchglobalpivot
788+
globalpivots = finder(
789+
tci, f, abstol,
790+
verbosity=verbosity
780791
)
781792
addglobalpivots!(tci, globalpivots)
782793
push!(nglobalpivots, length(globalpivots))
@@ -889,20 +900,7 @@ end
889900
f,
890901
localdims::Union{Vector{Int},NTuple{N,Int}},
891902
initialpivots::Vector{MultiIndex}=[ones(Int, length(localdims))];
892-
tolerance::Float64=1e-8,
893-
pivottolerance::Float64=tolerance,
894-
maxbonddim::Int=typemax(Int),
895-
maxiter::Int=200,
896-
sweepstrategy::Symbol=:backandforth,
897-
pivotsearch::Symbol=:full,
898-
verbosity::Int=0,
899-
loginterval::Int=10,
900-
normalizeerror::Bool=true,
901-
ncheckhistory=3,
902-
maxnglobalpivot::Int=5,
903-
nsearchglobalpivot::Int=5,
904-
tolmarginglobalsearch::Float64=10.0,
905-
strictlynested::Bool=false
903+
kwargs...
906904
) where {ValueType,N}
907905
908906
Cross interpolate a function ``f(\mathbf{u})`` using the TCI2 algorithm. Here, the domain of ``f`` is ``\mathbf{u} \in [1, \ldots, d_1] \times [1, \ldots, d_2] \times \ldots \times [1, \ldots, d_{\mathscr{L}}]`` and ``d_1 \ldots d_{\mathscr{L}}`` are the local dimensions.
@@ -912,27 +910,13 @@ Arguments:
912910
- `f` is the function to be interpolated. `f` should have a single parameter, which is a vector of the same length as `localdims`. The return type should be `ValueType`.
913911
- `localdims::Union{Vector{Int},NTuple{N,Int}}` is a `Vector` (or `Tuple`) that contains the local dimension of each index of `f`.
914912
- `initialpivots::Vector{MultiIndex}` is a vector of pivots to be used for initialization. Default: `[1, 1, ...]`.
915-
- `tolerance::Float64` is a float specifying the target tolerance for the interpolation. Default: `1e-8`.
916-
- `pivottolerance::Float64` is a float that specifies the tolerance for adding new pivots, i.e. the truncation of tensor train bonds. It should be <= tolerance, otherwise convergence may be impossible. Default: `tolerance`.
917-
- `maxbonddim::Int` specifies the maximum bond dimension for the TCI. Default: `typemax(Int)`, i.e. effectively unlimited.
918-
- `maxiter::Int` is the maximum number of iterations (i.e. optimization sweeps) before aborting the TCI construction. Default: `200`.
919-
- `sweepstrategy::Symbol` specifies whether to sweep forward (:forward), backward (:backward), or back and forth (:backandforth) during optimization. Default: `:backandforth`.
920-
- `pivotsearch::Symbol` determins how pivots are searched (`:full` or `:rook`). Default: `:full`.
921-
- `verbosity::Int` can be set to `>= 1` to get convergence information on standard output during optimization. Default: `0`.
922-
- `loginterval::Int` can be set to `>= 1` to specify how frequently to print convergence information. Default: `10`.
923-
- `normalizeerror::Bool` determines whether to scale the error by the maximum absolute value of `f` found during sampling. If set to `false`, the algorithm continues until the *absolute* error is below `tolerance`. If set to `true`, the algorithm uses the absolute error divided by the maximum sample instead. This is helpful if the magnitude of the function is not known in advance. Default: `true`.
924-
- `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`.
925-
- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`.
926-
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`.
927-
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`.
928-
- `strictlynested::Bool=false` determines whether to preserve partial nesting in the TCI algorithm. Default: `true`.
929-
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
913+
914+
Refer to [`optimize!`](@ref) for other keyword arguments such as `tolerance`, `maxbonddim`, `maxiter`.
930915
931916
Notes:
932917
- Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable.
933918
- By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate.
934919
935-
936920
See also: [`optimize!`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate1`](@ref)
937921
"""
938922
function crossinterpolate2(
@@ -947,7 +931,6 @@ function crossinterpolate2(
947931
return tci, ranks, errors
948932
end
949933

950-
951934
"""
952935
Search global pivots where the interpolation error exceeds `abstol`.
953936
"""

test/test_tensorci2.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,69 @@ import QuanticsGrids as QD
9999

100100

101101
end
102+
103+
104+
struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder
105+
npivots::Int
106+
end
107+
108+
function (finder::CustomGlobalPivotFinder)(
109+
tci::TensorCI2{ValueType},
110+
f,
111+
abstol::Float64;
112+
verbosity::Int=0
113+
)::Vector{MultiIndex} where {ValueType}
114+
L = length(tci.localdims)
115+
return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots]
116+
end
117+
118+
@testset "custom global pivot finder" begin
119+
pivotsearch = :full
120+
strictlynested = false
121+
nsearchglobalpivot = 10
122+
123+
# f(x) = exp(-x)
124+
Random.seed!(1240)
125+
R = 8
126+
abstol = 1e-4
127+
128+
grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,))
129+
130+
#index_to_x(i) = (i - 1) / 2^R # x ∈ [0, 1)
131+
fx(x) = exp(-x)
132+
f(bitlist::MultiIndex) = fx(QD.quantics_to_origcoord(grid, bitlist)[1])
133+
134+
localdims = fill(2, R)
135+
firstpivots = [ones(Int, R), vcat(1, fill(2, R - 1))]
136+
tci, ranks, errors = crossinterpolate2(
137+
Float64,
138+
f,
139+
localdims,
140+
firstpivots;
141+
tolerance=abstol,
142+
maxbonddim=1,
143+
maxiter=2,
144+
loginterval=1,
145+
verbosity=0,
146+
normalizeerror=false,
147+
globalpivotfinder=CustomGlobalPivotFinder(10),
148+
pivotsearch=pivotsearch,
149+
strictlynested=strictlynested
150+
)
151+
152+
@test all(TCI.linkdims(tci) .== 1)
153+
154+
# Conversion to TT
155+
tt = TCI.TensorTrain(tci)
156+
157+
for x in [0.1, 0.3, 0.6, 0.9]
158+
indexset = QD.origcoord_to_quantics(
159+
grid, (x,)
160+
)
161+
@test abs(TCI.evaluate(tci, indexset) - f(indexset)) < abstol
162+
@test abs(TCI.evaluate(tt, indexset) - f(indexset)) < abstol
163+
end
164+
end
102165

103166
@testset "trivial MPS(exp), small maxbonddim" begin
104167
pivotsearch = :full

0 commit comments

Comments
 (0)