Skip to content

Commit 907c038

Browse files
committed
Introduce GlobalPivotSearchInput
1 parent 8b35171 commit 907c038

File tree

5 files changed

+214
-51
lines changed

5 files changed

+214
-51
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
88
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112

1213
[compat]
1314
BitIntegers = "0.3.5"
1415
EllipsisNotation = "1"
1516
QuadGK = "2.9"
17+
Random = "1.11.0"
1618
julia = "1.6"
1719

1820
[extras]

src/TensorCrossInterpolation.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Base: ==, +
1313
# To define iterators and element access for MCI, TCI and TT objects
1414
import Base: isempty, iterate, getindex, lastindex, broadcastable
1515
import Base: length, size, sum
16+
import Random
1617

1718
export crossinterpolate1, crossinterpolate2, optfirstpivot
1819
export tensortrain, TensorTrain, sitedims, evaluate
@@ -31,10 +32,10 @@ include("abstracttensortrain.jl")
3132
include("cachedtensortrain.jl")
3233
include("batcheval.jl")
3334
include("cachedfunction.jl")
35+
include("tensortrain.jl")
3436
include("tensorci1.jl")
35-
include("tensorci2.jl")
3637
include("globalpivotfinder.jl")
37-
include("tensortrain.jl")
38+
include("tensorci2.jl")
3839
include("conversion.jl")
3940
include("integration.jl")
4041
include("contraction.jl")

src/globalpivotfinder.jl

Lines changed: 169 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,116 @@
1+
import Random: AbstractRNG, default_rng
12

23
"""
4+
GlobalPivotSearchInput{ValueType}
5+
6+
Input data structure for global pivot search algorithms.
7+
8+
# Fields
9+
- `localdims::Vector{Int}`: Dimensions of each tensor index
10+
- `current_tt::TensorTrain{ValueType,3}`: Current tensor train approximation
11+
- `maxsamplevalue::ValueType`: Maximum absolute value of the function
12+
- `Iset::Vector{Vector{Vector{Int}}}`: Set of left indices
13+
- `Jset::Vector{Vector{Vector{Int}}}`: Set of right indices
14+
"""
15+
struct GlobalPivotSearchInput{ValueType}
16+
localdims::Vector{Int}
17+
current_tt::TensorTrain{ValueType,3}
18+
maxsamplevalue::Float64
19+
Iset::Vector{Vector{MultiIndex}}
20+
Jset::Vector{Vector{MultiIndex}}
21+
22+
"""
23+
GlobalPivotSearchInput(
24+
localdims::Vector{Int},
25+
current_tt::TensorTrain{ValueType,3},
26+
maxsamplevalue::ValueType,
27+
Iset::Vector{Vector{MultiIndex}},
28+
Jset::Vector{Vector{MultiIndex}}
29+
) where {ValueType}
30+
31+
Construct a GlobalPivotSearchInput with the given fields.
32+
"""
33+
function GlobalPivotSearchInput{ValueType}(
34+
localdims::Vector{Int},
35+
current_tt::TensorTrain{ValueType,3},
36+
maxsamplevalue::Float64,
37+
Iset::Vector{Vector{MultiIndex}},
38+
Jset::Vector{Vector{MultiIndex}}
39+
) where {ValueType}
40+
new{ValueType}(
41+
localdims,
42+
current_tt,
43+
maxsamplevalue,
44+
Iset,
45+
Jset
46+
)
47+
end
48+
end
49+
50+
51+
"""
52+
AbstractGlobalPivotFinder
53+
54+
Abstract type for global pivot finders that search for indices with high interpolation error.
55+
"""
56+
abstract type AbstractGlobalPivotFinder end
57+
58+
"""
59+
(finder::AbstractGlobalPivotFinder)(
60+
input::GlobalPivotSearchInput{ValueType},
61+
f,
62+
abstol::Float64;
63+
verbosity::Int=0,
64+
rng::AbstractRNG=Random.default_rng()
65+
)::Vector{MultiIndex} where {ValueType}
66+
67+
Find global pivots using the given finder algorithm.
68+
69+
# Arguments
70+
- `input`: Input data for the search algorithm
71+
- `f`: Function to be interpolated
72+
- `abstol`: Absolute tolerance for the interpolation error
73+
- `verbosity`: Verbosity level (default: 0)
74+
- `rng`: Random number generator (default: Random.default_rng())
75+
76+
# Returns
77+
- `Vector{MultiIndex}`: Set of indices with high interpolation error
78+
"""
79+
function (finder::AbstractGlobalPivotFinder)(
80+
input::GlobalPivotSearchInput{ValueType},
81+
f,
82+
abstol::Float64;
83+
verbosity::Int=0,
84+
rng::AbstractRNG=Random.default_rng()
85+
)::Vector{MultiIndex} where {ValueType}
86+
error("find_global_pivots not implemented for $(typeof(finder))")
87+
end
88+
89+
"""
90+
DefaultGlobalPivotFinder
91+
392
Default implementation of global pivot finder that uses random search.
93+
94+
# Fields
95+
- `nsearch::Int`: Number of initial points to search from
96+
- `maxnglobalpivot::Int`: Maximum number of pivots to add in each iteration
97+
- `tolmarginglobalsearch::Float64`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor
498
"""
599
struct DefaultGlobalPivotFinder <: AbstractGlobalPivotFinder
6100
nsearch::Int
7101
maxnglobalpivot::Int
8102
tolmarginglobalsearch::Float64
9103
end
10104

11-
# Constructor for backward compatibility
105+
"""
106+
DefaultGlobalPivotFinder(;
107+
nsearch::Int=5,
108+
maxnglobalpivot::Int=5,
109+
tolmarginglobalsearch::Float64=10.0
110+
)
111+
112+
Construct a DefaultGlobalPivotFinder with the given parameters.
113+
"""
12114
function DefaultGlobalPivotFinder(;
13115
nsearch::Int=5,
14116
maxnglobalpivot::Int=5,
@@ -18,41 +120,76 @@ function DefaultGlobalPivotFinder(;
18120
end
19121

20122
"""
21-
Find global pivots where the interpolation error exceeds the tolerance.
22-
"""
23-
function (finder::AbstractGlobalPivotFinder)(
24-
tci::TensorCI2{ValueType},
25-
f,
26-
abstol::Float64;
27-
verbosity::Int=0
28-
)::Vector{MultiIndex} where {ValueType}
29-
error("find_global_pivots not implemented for $(typeof(finder))")
30-
end
123+
(finder::DefaultGlobalPivotFinder)(
124+
input::GlobalPivotSearchInput{ValueType},
125+
f,
126+
abstol::Float64;
127+
verbosity::Int=0,
128+
rng::AbstractRNG=Random.default_rng()
129+
)::Vector{MultiIndex} where {ValueType}
31130
32-
# Default implementation using the existing searchglobalpivots
131+
Find global pivots using random search.
132+
133+
# Arguments
134+
- `input`: Input data for the search algorithm
135+
- `f`: Function to be interpolated
136+
- `abstol`: Absolute tolerance for the interpolation error
137+
- `verbosity`: Verbosity level (default: 0)
138+
- `rng`: Random number generator (default: Random.default_rng())
139+
140+
# Returns
141+
- `Vector{MultiIndex}`: Set of indices with high interpolation error
142+
"""
33143
function (finder::DefaultGlobalPivotFinder)(
34-
tci::TensorCI2{ValueType},
144+
input::GlobalPivotSearchInput{ValueType},
35145
f,
36146
abstol::Float64;
37-
verbosity::Int=0
147+
verbosity::Int=0,
148+
rng::AbstractRNG=Random.default_rng()
38149
)::Vector{MultiIndex} where {ValueType}
39-
return searchglobalpivots(
40-
tci, f, finder.tolmarginglobalsearch * abstol,
41-
verbosity=verbosity,
42-
nsearch=finder.nsearch,
43-
maxnglobalpivot=finder.maxnglobalpivot
44-
)
45-
end
150+
L = length(input.localdims)
151+
nsearch = finder.nsearch
152+
maxnglobalpivot = finder.maxnglobalpivot
153+
tolmarginglobalsearch = finder.tolmarginglobalsearch
46154

47-
# Helper function for backward compatibility
48-
function _create_default_finder(;
49-
nsearch::Int=5,
50-
maxnglobalpivot::Int=5,
51-
tolmarginglobalsearch::Float64=10.0
52-
)
53-
return DefaultGlobalPivotFinder(
54-
nsearch=nsearch,
55-
maxnglobalpivot=maxnglobalpivot,
56-
tolmarginglobalsearch=tolmarginglobalsearch
57-
)
155+
# Generate random initial points
156+
initial_points = [[rand(rng, 1:input.localdims[p]) for p in 1:L] for _ in 1:nsearch]
157+
158+
# Find pivots with high interpolation error
159+
found_pivots = MultiIndex[]
160+
for point in initial_points
161+
# Perform local search from each initial point
162+
current_point = copy(point)
163+
best_error = 0.0
164+
best_point = copy(point)
165+
166+
# Local search
167+
for p in 1:L
168+
for v in 1:input.localdims[p]
169+
current_point[p] = v
170+
error = abs(f(current_point) - input.current_tt(current_point))
171+
if error > best_error
172+
best_error = error
173+
best_point = copy(current_point)
174+
end
175+
end
176+
current_point[p] = point[p] # Reset to original point
177+
end
178+
179+
# Add point if error is above threshold
180+
if best_error > abstol * tolmarginglobalsearch
181+
push!(found_pivots, best_point)
182+
end
183+
end
184+
185+
# Limit number of pivots
186+
if length(found_pivots) > maxnglobalpivot
187+
found_pivots = found_pivots[1:maxnglobalpivot]
188+
end
189+
190+
if verbosity > 0
191+
println("Found $(length(found_pivots)) global pivots")
192+
end
193+
194+
return found_pivots
58195
end

src/tensorci2.jl

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
"""
2-
Abstract type for global pivot finders in TCI2 algorithm.
3-
"""
4-
abstract type AbstractGlobalPivotFinder end
5-
61
"""
72
mutable struct TensorCI2{ValueType} <: AbstractTensorTrain{ValueType}
83
@@ -617,7 +612,8 @@ function convergencecriterion(
617612
nglobalpivots::AbstractVector{Int},
618613
tolerance::Float64,
619614
maxbonddim::Int,
620-
ncheckhistory::Int,
615+
ncheckhistory::Int;
616+
checkconvglobalpivot::Bool=true
621617
)::Bool
622618
if length(errors) < ncheckhistory
623619
return false
@@ -626,12 +622,27 @@ function convergencecriterion(
626622
lastngpivots = last(nglobalpivots, ncheckhistory)
627623
return (
628624
all(last(errors, ncheckhistory) .< tolerance) &&
629-
all(lastngpivots .== 0) &&
625+
(checkconvglobalpivot ? all(lastngpivots .== 0) : true) &&
630626
minimum(lastranks) == lastranks[end]
631627
) || all(lastranks .>= maxbonddim)
632628
end
633629

634630

631+
"""
632+
GlobalPivotSearchInput(tci::TensorCI2{ValueType}) where {ValueType}
633+
634+
Construct a GlobalPivotSearchInput from a TensorCI2 object.
635+
"""
636+
function GlobalPivotSearchInput(tci::TensorCI2{ValueType}) where {ValueType}
637+
return GlobalPivotSearchInput{ValueType}(
638+
tci.localdims,
639+
TensorTrain(tci),
640+
tci.maxsamplevalue,
641+
tci.Iset,
642+
tci.Jset
643+
)
644+
end
645+
635646

636647
"""
637648
function optimize!(
@@ -670,10 +681,14 @@ Arguments:
670681
- `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`.
671682
- `globalpivotfinder::Union{AbstractGlobalPivotFinder, Nothing}` is a global pivot finder to use for searching global pivots. Default: `nothing`. If `nothing`, a default global pivot finder is used.
672683
- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`. The maximum number of global pivots to add in each iteration.
673-
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`. This parameter is used for the default global pivot finder. Deprecated.
674-
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`. This parameter is used for the default global pivot finder. Deprecated.
675684
- `strictlynested::Bool` determines whether to preserve partial nesting in the TCI algorithm. Default: `false`.
676685
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
686+
- `checkconvglobalpivot::Bool` Check if the global pivot finder is converged. Default: `true`. In the future, this will be set to `false` by default.
687+
688+
Arguments (deprecated):
689+
- `pivottolerance::Float64` is the tolerance for the pivot search. Deprecated.
690+
- `nsearchglobalpivot::Int` is the number of search points for the global pivot finder. Deprecated.
691+
- `tolmarginglobalsearch::Float64` is the tolerance for the global pivot finder. Deprecated.
677692
678693
Notes:
679694
- Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable.
@@ -700,7 +715,8 @@ function optimize!(
700715
nsearchglobalpivot::Int=5,
701716
tolmarginglobalsearch::Float64=10.0,
702717
strictlynested::Bool=false,
703-
checkbatchevaluatable::Bool=false
718+
checkbatchevaluatable::Bool=false,
719+
checkconvglobalpivot::Bool=true
704720
) where {ValueType}
705721
errors = Float64[]
706722
ranks = Int[]
@@ -785,9 +801,11 @@ function optimize!(
785801
end
786802

787803
# Find global pivots where the error is too large
804+
input = GlobalPivotSearchInput(tci)
788805
globalpivots = finder(
789-
tci, f, abstol,
790-
verbosity=verbosity
806+
input, f, abstol;
807+
verbosity=verbosity,
808+
rng=Random.default_rng()
791809
)
792810
addglobalpivots!(tci, globalpivots)
793811
push!(nglobalpivots, length(globalpivots))
@@ -803,7 +821,10 @@ function optimize!(
803821
flush(stdout)
804822
end
805823
if convergencecriterion(
806-
ranks, errors, nglobalpivots, abstol, maxbonddim, ncheckhistory
824+
ranks, errors,
825+
nglobalpivots,
826+
abstol, maxbonddim, ncheckhistory;
827+
checkconvglobalpivot=checkconvglobalpivot
807828
)
808829
break
809830
end

test/test_tensorci2.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Test
22
import TensorCrossInterpolation as TCI
33
import TensorCrossInterpolation: rank, linkdims, TensorCI2, updatepivots!, addglobalpivots1sitesweep!, MultiIndex, evaluate, crossinterpolate2, pivoterror, tensortrain, optimize!
44
import Random
5+
import Random: AbstractRNG
56
import QuanticsGrids as QD
67

78
@testset "TensorCI2" begin
@@ -106,13 +107,14 @@ import QuanticsGrids as QD
106107
end
107108

108109
function (finder::CustomGlobalPivotFinder)(
109-
tci::TensorCI2{ValueType},
110+
input::TCI.GlobalPivotSearchInput{ValueType},
110111
f,
111112
abstol::Float64;
112-
verbosity::Int=0
113+
verbosity::Int=0,
114+
rng::AbstractRNG=Random.default_rng()
113115
)::Vector{MultiIndex} where {ValueType}
114-
L = length(tci.localdims)
115-
return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots]
116+
L = length(input.localdims)
117+
return [[rand(rng, 1:input.localdims[p]) for p in 1:L] for _ in 1:finder.npivots]
116118
end
117119

118120
@testset "custom global pivot finder" begin

0 commit comments

Comments
 (0)