Skip to content

Commit 3813134

Browse files
authored
Merge pull request #62 from tensor4all/61-allow-custom-global-pivot-finder
61 allow custom global pivot finder
2 parents 883a8fc + 62a8ccf commit 3813134

File tree

7 files changed

+379
-50
lines changed

7 files changed

+379
-50
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
88
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112

1213
[compat]
1314
BitIntegers = "0.3.5"
1415
EllipsisNotation = "1"
1516
QuadGK = "2.9"
17+
Random = "1.10.0"
1618
julia = "1.6"
1719

1820
[extras]

docs/src/documentation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Pages = ["tensorci1.jl", "indexset.jl", "sweepstrategies.jl"]
4141
### Tensor cross interpolation 2 (TCI2)
4242
```@autodocs
4343
Modules = [TensorCrossInterpolation]
44-
Pages = ["tensorci2.jl"]
44+
Pages = ["tensorci2.jl", "globalpivotfinder.jl"]
4545
```
4646

4747
### Integration

docs/src/index.md

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ end
242242

243243
`CachedFunction{T}` can wrap a function inheriting from `BatchEvaluator{T}`. In such cases, `CachedFunction{T}` caches the results of batch evaluation.
244244

245-
# Batch evaluation + parallelization
245+
## Batch evaluation + parallelization
246246
The batch evalution can be combined with parallelization using threads, MPI, etc.
247247
The following sample code use `Threads` to parallelize function evaluations.
248248
Note that the function evaluation for a single index set must be thread-safe.
@@ -350,3 +350,64 @@ end
350350
```
351351

352352
You can simply pass the wrapped function `parf` to `crossinterpolate2`.
353+
354+
## Global pivot finder
355+
A each TCI2 sweep, we can find the index sets with high interpolation error and add them to the TCI2 object.
356+
By default, we use a greedy search algorithm to find the index sets with high interpolation error.
357+
However, this may not be effective in some cases.
358+
In such cases, you can use a custom global pivot finder, which must inherit from `TCI.AbstractGlobalPivotFinder`.
359+
360+
Here's an example of a custom global pivot finder that randomly selects pivots:
361+
362+
```julia
363+
import TensorCrossInterpolation as TCI
364+
365+
struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder
366+
npivots::Int
367+
end
368+
369+
function (finder::CustomGlobalPivotFinder)(
370+
tci::TensorCI2{ValueType},
371+
f,
372+
abstol::Float64;
373+
verbosity::Int=0
374+
)::Vector{MultiIndex} where {ValueType}
375+
L = length(tci.localdims)
376+
return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots]
377+
end
378+
```
379+
380+
You can use this custom finder by passing it to the `optimize!` function:
381+
382+
```julia
383+
tci, ranks, errors = crossinterpolate2(
384+
Float64,
385+
f,
386+
localdims,
387+
firstpivots;
388+
globalpivotfinder=CustomGlobalPivotFinder(10) # Use custom finder that adds 10 random pivots
389+
)
390+
```
391+
392+
The default global pivot finder (`DefaultGlobalPivotFinder`) uses a greedy search algorithm to find index sets with high interpolation error. It has the following parameters:
393+
394+
- `nsearch`: Number of initial points to search from (default: 5)
395+
- `maxnglobalpivot`: Maximum number of pivots to add in each iteration (default: 5)
396+
- `tolmarginglobalsearch`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor (default: 10.0)
397+
398+
You can customize these parameters by creating a `DefaultGlobalPivotFinder` instance:
399+
400+
```julia
401+
finder = TCI.DefaultGlobalPivotFinder(
402+
nsearch=10, # Search from 10 initial points
403+
maxnglobalpivot=3, # Add at most 3 pivots per iteration
404+
tolmarginglobalsearch=5.0 # Search for errors > 5 * tolerance
405+
)
406+
tci, ranks, errors = crossinterpolate2(
407+
Float64,
408+
f,
409+
localdims,
410+
firstpivots;
411+
globalpivotfinder=finder
412+
)
413+
```

src/TensorCrossInterpolation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Base: ==, +
1313
# To define iterators and element access for MCI, TCI and TT objects
1414
import Base: isempty, iterate, getindex, lastindex, broadcastable
1515
import Base: length, size, sum
16+
import Random
1617

1718
export crossinterpolate1, crossinterpolate2, optfirstpivot
1819
export tensortrain, TensorTrain, sitedims, evaluate
@@ -31,9 +32,10 @@ include("abstracttensortrain.jl")
3132
include("cachedtensortrain.jl")
3233
include("batcheval.jl")
3334
include("cachedfunction.jl")
35+
include("tensortrain.jl")
3436
include("tensorci1.jl")
37+
include("globalpivotfinder.jl")
3538
include("tensorci2.jl")
36-
include("tensortrain.jl")
3739
include("conversion.jl")
3840
include("integration.jl")
3941
include("contraction.jl")

src/globalpivotfinder.jl

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import Random: AbstractRNG, default_rng
2+
3+
"""
4+
GlobalPivotSearchInput{ValueType}
5+
6+
Input data structure for global pivot search algorithms.
7+
8+
# Fields
9+
- `localdims::Vector{Int}`: Dimensions of each tensor index
10+
- `current_tt::TensorTrain{ValueType,3}`: Current tensor train approximation
11+
- `maxsamplevalue::ValueType`: Maximum absolute value of the function
12+
- `Iset::Vector{Vector{MultiIndex}}`: Set of left indices
13+
- `Jset::Vector{Vector{MultiIndex}}`: Set of right indices
14+
"""
15+
struct GlobalPivotSearchInput{ValueType}
16+
localdims::Vector{Int}
17+
current_tt::TensorTrain{ValueType,3}
18+
maxsamplevalue::Float64
19+
Iset::Vector{Vector{MultiIndex}}
20+
Jset::Vector{Vector{MultiIndex}}
21+
22+
"""
23+
GlobalPivotSearchInput(
24+
localdims::Vector{Int},
25+
current_tt::TensorTrain{ValueType,3},
26+
maxsamplevalue::ValueType,
27+
Iset::Vector{Vector{MultiIndex}},
28+
Jset::Vector{Vector{MultiIndex}}
29+
) where {ValueType}
30+
31+
Construct a GlobalPivotSearchInput with the given fields.
32+
"""
33+
function GlobalPivotSearchInput{ValueType}(
34+
localdims::Vector{Int},
35+
current_tt::TensorTrain{ValueType,3},
36+
maxsamplevalue::Float64,
37+
Iset::Vector{Vector{MultiIndex}},
38+
Jset::Vector{Vector{MultiIndex}}
39+
) where {ValueType}
40+
new{ValueType}(
41+
localdims,
42+
current_tt,
43+
maxsamplevalue,
44+
Iset,
45+
Jset
46+
)
47+
end
48+
end
49+
50+
51+
"""
52+
AbstractGlobalPivotFinder
53+
54+
Abstract type for global pivot finders that search for indices with high interpolation error.
55+
"""
56+
abstract type AbstractGlobalPivotFinder end
57+
58+
"""
59+
(finder::AbstractGlobalPivotFinder)(
60+
input::GlobalPivotSearchInput{ValueType},
61+
f,
62+
abstol::Float64;
63+
verbosity::Int=0,
64+
rng::AbstractRNG=Random.default_rng()
65+
)::Vector{MultiIndex} where {ValueType}
66+
67+
Find global pivots using the given finder algorithm.
68+
69+
# Arguments
70+
- `input`: Input data for the search algorithm
71+
- `f`: Function to be interpolated
72+
- `abstol`: Absolute tolerance for the interpolation error
73+
- `verbosity`: Verbosity level (default: 0)
74+
- `rng`: Random number generator (default: Random.default_rng())
75+
76+
# Returns
77+
- `Vector{MultiIndex}`: Set of indices with high interpolation error
78+
"""
79+
function (finder::AbstractGlobalPivotFinder)(
80+
input::GlobalPivotSearchInput{ValueType},
81+
f,
82+
abstol::Float64;
83+
verbosity::Int=0,
84+
rng::AbstractRNG=Random.default_rng()
85+
)::Vector{MultiIndex} where {ValueType}
86+
error("find_global_pivots not implemented for $(typeof(finder))")
87+
end
88+
89+
"""
90+
DefaultGlobalPivotFinder
91+
92+
Default implementation of global pivot finder that uses random search.
93+
94+
# Fields
95+
- `nsearch::Int`: Number of initial points to search from
96+
- `maxnglobalpivot::Int`: Maximum number of pivots to add in each iteration
97+
- `tolmarginglobalsearch::Float64`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor
98+
"""
99+
struct DefaultGlobalPivotFinder <: AbstractGlobalPivotFinder
100+
nsearch::Int
101+
maxnglobalpivot::Int
102+
tolmarginglobalsearch::Float64
103+
end
104+
105+
"""
106+
DefaultGlobalPivotFinder(;
107+
nsearch::Int=5,
108+
maxnglobalpivot::Int=5,
109+
tolmarginglobalsearch::Float64=10.0
110+
)
111+
112+
Construct a DefaultGlobalPivotFinder with the given parameters.
113+
"""
114+
function DefaultGlobalPivotFinder(;
115+
nsearch::Int=5,
116+
maxnglobalpivot::Int=5,
117+
tolmarginglobalsearch::Float64=10.0
118+
)
119+
return DefaultGlobalPivotFinder(nsearch, maxnglobalpivot, tolmarginglobalsearch)
120+
end
121+
122+
"""
123+
(finder::DefaultGlobalPivotFinder)(
124+
input::GlobalPivotSearchInput{ValueType},
125+
f,
126+
abstol::Float64;
127+
verbosity::Int=0,
128+
rng::AbstractRNG=Random.default_rng()
129+
)::Vector{MultiIndex} where {ValueType}
130+
131+
Find global pivots using random search.
132+
133+
# Arguments
134+
- `input`: Input data for the search algorithm
135+
- `f`: Function to be interpolated
136+
- `abstol`: Absolute tolerance for the interpolation error
137+
- `verbosity`: Verbosity level (default: 0)
138+
- `rng`: Random number generator (default: Random.default_rng())
139+
140+
# Returns
141+
- `Vector{MultiIndex}`: Set of indices with high interpolation error
142+
"""
143+
function (finder::DefaultGlobalPivotFinder)(
144+
input::GlobalPivotSearchInput{ValueType},
145+
f,
146+
abstol::Float64;
147+
verbosity::Int=0,
148+
rng::AbstractRNG=Random.default_rng()
149+
)::Vector{MultiIndex} where {ValueType}
150+
L = length(input.localdims)
151+
nsearch = finder.nsearch
152+
maxnglobalpivot = finder.maxnglobalpivot
153+
tolmarginglobalsearch = finder.tolmarginglobalsearch
154+
155+
# Generate random initial points
156+
initial_points = [[rand(rng, 1:input.localdims[p]) for p in 1:L] for _ in 1:nsearch]
157+
158+
# Find pivots with high interpolation error
159+
found_pivots = MultiIndex[]
160+
for point in initial_points
161+
# Perform local search from each initial point
162+
current_point = copy(point)
163+
best_error = 0.0
164+
best_point = copy(point)
165+
166+
# Local search
167+
for p in 1:L
168+
for v in 1:input.localdims[p]
169+
current_point[p] = v
170+
error = abs(f(current_point) - input.current_tt(current_point))
171+
if error > best_error
172+
best_error = error
173+
best_point = copy(current_point)
174+
end
175+
end
176+
current_point[p] = point[p] # Reset to original point
177+
end
178+
179+
# Add point if error is above threshold
180+
if best_error > abstol * tolmarginglobalsearch
181+
push!(found_pivots, best_point)
182+
end
183+
end
184+
185+
# Limit number of pivots
186+
if length(found_pivots) > maxnglobalpivot
187+
found_pivots = found_pivots[1:maxnglobalpivot]
188+
end
189+
190+
if verbosity > 0
191+
println("Found $(length(found_pivots)) global pivots")
192+
end
193+
194+
return found_pivots
195+
end

0 commit comments

Comments
 (0)