Skip to content

Commit a80ee50

Browse files
committed
Merge branch '47-better-global-error-estimate' into 'main'
Implement estimatetrueerror Closes #47 See merge request tensors4fields/TensorCrossInterpolation.jl!90
2 parents 9bb0836 + 2c39a96 commit a80ee50

File tree

6 files changed

+133
-1
lines changed

6 files changed

+133
-1
lines changed

docs/src/documentation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,6 @@ Pages = ["integration.jl"]
5353
## Helpers and utility methods
5454
```@autodocs
5555
Modules = [TensorCrossInterpolation]
56-
Pages = ["cachedfunction.jl", "batcheval.jl", "util.jl"]
56+
Pages = ["cachedfunction.jl", "batcheval.jl", "util.jl", "globalsearch.jl"]
5757
```
5858

docs/src/index.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@ tci, ranks, errors = TCI.crossinterpolate2(
117117
```
118118
This algorithm optimizes the given index set (in this case `[1, 2, 3, 4, 5]`) by searching for a maximum absolute value, alternating through the dimensions. If no starting point is given, `[1, 1, ...]` is used.
119119

120+
## Estiamte true interpolation error by random global search
121+
Since the TCI update algorithms are local, the true interpolation error is not known. However, the error can be estimated by global searches. This is implemented in the function `estimatetrueerror`:
122+
123+
```julia
124+
pivoterrors = TCI.estimatetrueerror(TCI.TensorTrain(tci), f)
125+
```
126+
127+
This function approximately estimates the error that would be reached by repeating a greedy search from a random initial point.
128+
The result is a vector of a found indexset and the corresponding error, sorted by error. The error is the maximum absolute difference between the function and the TT approximation.
129+
120130
## Caching
121131
During constructing a TCI, the function to be interpolated can be evaluated for the same index set multiple times.
122132
If an evaluation of the function to be interpolated is costly, i.e., takes more than 100 ns,

src/TensorCrossInterpolation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ include("tensortrain.jl")
3636
include("conversion.jl")
3737
include("integration.jl")
3838
include("contraction.jl")
39+
include("globalsearch.jl")
3940

4041
end

src/globalsearch.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
function estimatetrueerror(
3+
tt::TensorTrain{ValueType,3}, f;
4+
nsearch::Int = 100,
5+
initialpoints::Union{Nothing,AbstractVector{MultiIndex}} = nothing,
6+
)::Vector{Tuple{MultiIndex,Float64}} where {ValueType}
7+
8+
Estimate the global error by comparing the exact function value and
9+
the TT approximation using a greedy algorithm and returns a unique list of pairs of the pivot and the error (error is the absolute difference between the exact function value and the TT approximation).
10+
On return, the list is sorted in descending order of the error.
11+
12+
Arguments:
13+
- `tt::TensorTrain{ValueType,3}`: The tensor train to be compared with `f`
14+
- `f`: The function to be compared with `tt`.
15+
- `nsearch::Int`: The number of initial points to be used in the search (defaults to 100).
16+
- `initialpoints::Union{Nothing,AbstractVector{MultiIndex}}`: The initial points to be used in the search (defaults to `nothing`). If `initialpoints` is not `nothing`, `nsearch` is ignored.
17+
"""
18+
function estimatetrueerror(
19+
tt::TensorTrain{ValueType,3}, f;
20+
nsearch::Int = 100,
21+
initialpoints::Union{Nothing,AbstractVector{MultiIndex}} = nothing,
22+
)::Vector{Tuple{MultiIndex,Float64}} where {ValueType}
23+
if nsearch <= 0 && initialpoints === nothing
24+
error("No search is performed")
25+
end
26+
nsearch >= 0 || error("nsearch must be non-negative")
27+
28+
if nsearch > 0 && initialpoints === nothing
29+
# Use random initial points
30+
initialpoints = [[rand(1:first(d)) for d in sitedims(tt)] for _ in 1:nsearch]
31+
end
32+
33+
ttcache = TTCache(tt)
34+
35+
pivoterror = [_floatingzone(ttcache, f, initp) for initp in initialpoints]
36+
37+
p = sortperm([e for (_, e) in pivoterror], rev=true)
38+
39+
return unique(pivoterror[p])
40+
end
41+
42+
43+
function _floatingzone(
44+
ttcache::TTCache{ValueType}, f, pivot
45+
)::Tuple{MultiIndex,Float64} where {ValueType}
46+
n = length(ttcache)
47+
48+
maxerror = abs(f(pivot) - ttcache(pivot))
49+
localdims = first.(sitedims(ttcache))
50+
51+
while true
52+
prev_maxerror = maxerror
53+
for ipos in 1:n
54+
exactdata = filltensor(
55+
ValueType,
56+
f,
57+
localdims,
58+
[pivot[1:ipos-1]],
59+
[pivot[ipos+1:end]],
60+
Val(1)
61+
)
62+
prediction = filltensor(
63+
ValueType,
64+
ttcache,
65+
localdims,
66+
[pivot[1:ipos-1]],
67+
[pivot[ipos+1:end]],
68+
Val(1)
69+
)
70+
err = vec(abs.(exactdata .- prediction))
71+
pivot[ipos] = argmax(err)
72+
maxerror = maximum(err)
73+
end
74+
75+
if maxerror == prev_maxerror
76+
break
77+
end
78+
end
79+
80+
return pivot, maxerror
81+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import TensorCrossInterpolation as TCI
22
using Test
33
using LinearAlgebra
44

5+
#==
56
include("test_with_aqua.jl")
67
include("test_with_jet.jl")
78
include("test_util.jl")
@@ -19,3 +20,5 @@ include("test_tensortrain.jl")
1920
include("test_conversion.jl")
2021
include("test_contraction.jl")
2122
include("test_integration.jl")
23+
==#
24+
include("test_globalsearch.jl")

test/test_globalsearch.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using Test
2+
import TensorCrossInterpolation as TCI
3+
import TensorCrossInterpolation: crossinterpolate2, MultiIndex
4+
import Random
5+
import QuanticsGrids as QD
6+
7+
@testset "globalsearch" begin
8+
Random.seed!(1240)
9+
R = 20
10+
abstol = 1e-10
11+
12+
grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,))
13+
14+
fx(x) = exp(-x) + 1e-3 * sin(1000 * x)
15+
f(bitlist::MultiIndex) = fx(QD.quantics_to_origcoord(grid, bitlist)[1])
16+
17+
abstol = 1e-4
18+
localdims = fill(2, R)
19+
firstpivots = [ones(Int, R), vcat(1, fill(2, R - 1))]
20+
tci, ranks, errors = crossinterpolate2(
21+
Float64,
22+
f,
23+
localdims,
24+
firstpivots;
25+
tolerance=abstol,
26+
maxbonddim=1,
27+
verbosity=1,
28+
normalizeerror=false,
29+
)
30+
31+
32+
pivoterrors = TCI.estimatetrueerror(TCI.TensorTrain(tci), f)
33+
34+
errors = [e for (_, e) in pivoterrors]
35+
@test all([abs(f(p) - tci(p)) for (p, _) in pivoterrors] .== errors)
36+
@test all(errors[1:end-1] .>= errors[2:end]) # check if errors are sorted in descending order
37+
end

0 commit comments

Comments
 (0)