Skip to content

Commit 8b35171

Browse files
committed
Add description on global pivot finder to docs
1 parent d020ead commit 8b35171

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

docs/src/documentation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Pages = ["tensorci1.jl", "indexset.jl", "sweepstrategies.jl"]
4141
### Tensor cross interpolation 2 (TCI2)
4242
```@autodocs
4343
Modules = [TensorCrossInterpolation]
44-
Pages = ["tensorci2.jl"]
44+
Pages = ["tensorci2.jl", "globalpivotfinder.jl"]
4545
```
4646

4747
### Integration

docs/src/index.md

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ end
242242

243243
`CachedFunction{T}` can wrap a function inheriting from `BatchEvaluator{T}`. In such cases, `CachedFunction{T}` caches the results of batch evaluation.
244244

245-
# Batch evaluation + parallelization
245+
## Batch evaluation + parallelization
246246
The batch evalution can be combined with parallelization using threads, MPI, etc.
247247
The following sample code use `Threads` to parallelize function evaluations.
248248
Note that the function evaluation for a single index set must be thread-safe.
@@ -350,3 +350,64 @@ end
350350
```
351351

352352
You can simply pass the wrapped function `parf` to `crossinterpolate2`.
353+
354+
## Global pivot finder
355+
A each TCI2 sweep, we can find the index sets with high interpolation error and add them to the TCI2 object.
356+
By default, we use a greedy search algorithm to find the index sets with high interpolation error.
357+
However, this may not be effective in some cases.
358+
In such cases, you can use a custom global pivot finder, which must inherit from `TCI.AbstractGlobalPivotFinder`.
359+
360+
Here's an example of a custom global pivot finder that randomly selects pivots:
361+
362+
```julia
363+
import TensorCrossInterpolation as TCI
364+
365+
struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder
366+
npivots::Int
367+
end
368+
369+
function (finder::CustomGlobalPivotFinder)(
370+
tci::TensorCI2{ValueType},
371+
f,
372+
abstol::Float64;
373+
verbosity::Int=0
374+
)::Vector{MultiIndex} where {ValueType}
375+
L = length(tci.localdims)
376+
return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots]
377+
end
378+
```
379+
380+
You can use this custom finder by passing it to the `optimize!` function:
381+
382+
```julia
383+
tci, ranks, errors = crossinterpolate2(
384+
Float64,
385+
f,
386+
localdims,
387+
firstpivots;
388+
globalpivotfinder=CustomGlobalPivotFinder(10) # Use custom finder that adds 10 random pivots
389+
)
390+
```
391+
392+
The default global pivot finder (`DefaultGlobalPivotFinder`) uses a greedy search algorithm to find index sets with high interpolation error. It has the following parameters:
393+
394+
- `nsearch`: Number of initial points to search from (default: 5)
395+
- `maxnglobalpivot`: Maximum number of pivots to add in each iteration (default: 5)
396+
- `tolmarginglobalsearch`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor (default: 10.0)
397+
398+
You can customize these parameters by creating a `DefaultGlobalPivotFinder` instance:
399+
400+
```julia
401+
finder = TCI.DefaultGlobalPivotFinder(
402+
nsearch=10, # Search from 10 initial points
403+
maxnglobalpivot=3, # Add at most 3 pivots per iteration
404+
tolmarginglobalsearch=5.0 # Search for errors > 5 * tolerance
405+
)
406+
tci, ranks, errors = crossinterpolate2(
407+
Float64,
408+
f,
409+
localdims,
410+
firstpivots;
411+
globalpivotfinder=finder
412+
)
413+
```

0 commit comments

Comments
 (0)