|
242 | 242 |
|
243 | 243 | `CachedFunction{T}` can wrap a function inheriting from `BatchEvaluator{T}`. In such cases, `CachedFunction{T}` caches the results of batch evaluation.
|
244 | 244 |
|
245 |
| -# Batch evaluation + parallelization |
| 245 | +## Batch evaluation + parallelization |
246 | 246 | The batch evalution can be combined with parallelization using threads, MPI, etc.
|
247 | 247 | The following sample code use `Threads` to parallelize function evaluations.
|
248 | 248 | Note that the function evaluation for a single index set must be thread-safe.
|
|
350 | 350 | ```
|
351 | 351 |
|
352 | 352 | You can simply pass the wrapped function `parf` to `crossinterpolate2`.
|
| 353 | + |
| 354 | +## Global pivot finder |
| 355 | +A each TCI2 sweep, we can find the index sets with high interpolation error and add them to the TCI2 object. |
| 356 | +By default, we use a greedy search algorithm to find the index sets with high interpolation error. |
| 357 | +However, this may not be effective in some cases. |
| 358 | +In such cases, you can use a custom global pivot finder, which must inherit from `TCI.AbstractGlobalPivotFinder`. |
| 359 | + |
| 360 | +Here's an example of a custom global pivot finder that randomly selects pivots: |
| 361 | + |
| 362 | +```julia |
| 363 | +import TensorCrossInterpolation as TCI |
| 364 | + |
| 365 | +struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder |
| 366 | + npivots::Int |
| 367 | +end |
| 368 | + |
| 369 | +function (finder::CustomGlobalPivotFinder)( |
| 370 | + tci::TensorCI2{ValueType}, |
| 371 | + f, |
| 372 | + abstol::Float64; |
| 373 | + verbosity::Int=0 |
| 374 | +)::Vector{MultiIndex} where {ValueType} |
| 375 | + L = length(tci.localdims) |
| 376 | + return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots] |
| 377 | +end |
| 378 | +``` |
| 379 | + |
| 380 | +You can use this custom finder by passing it to the `optimize!` function: |
| 381 | + |
| 382 | +```julia |
| 383 | +tci, ranks, errors = crossinterpolate2( |
| 384 | + Float64, |
| 385 | + f, |
| 386 | + localdims, |
| 387 | + firstpivots; |
| 388 | + globalpivotfinder=CustomGlobalPivotFinder(10) # Use custom finder that adds 10 random pivots |
| 389 | +) |
| 390 | +``` |
| 391 | + |
| 392 | +The default global pivot finder (`DefaultGlobalPivotFinder`) uses a greedy search algorithm to find index sets with high interpolation error. It has the following parameters: |
| 393 | + |
| 394 | +- `nsearch`: Number of initial points to search from (default: 5) |
| 395 | +- `maxnglobalpivot`: Maximum number of pivots to add in each iteration (default: 5) |
| 396 | +- `tolmarginglobalsearch`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor (default: 10.0) |
| 397 | + |
| 398 | +You can customize these parameters by creating a `DefaultGlobalPivotFinder` instance: |
| 399 | + |
| 400 | +```julia |
| 401 | +finder = TCI.DefaultGlobalPivotFinder( |
| 402 | + nsearch=10, # Search from 10 initial points |
| 403 | + maxnglobalpivot=3, # Add at most 3 pivots per iteration |
| 404 | + tolmarginglobalsearch=5.0 # Search for errors > 5 * tolerance |
| 405 | +) |
| 406 | +tci, ranks, errors = crossinterpolate2( |
| 407 | + Float64, |
| 408 | + f, |
| 409 | + localdims, |
| 410 | + firstpivots; |
| 411 | + globalpivotfinder=finder |
| 412 | +) |
| 413 | +``` |
0 commit comments