Skip to content

Commit b44e817

Browse files
author
Marc Ritter
committed
Merge branch 'add-crossinterpolate1-function' into 'main'
renamed crossinterpolate->crossinterpolate1 See merge request tensors4fields/TensorCrossInterpolation.jl!72
2 parents 72f0307 + 2c3681c commit b44e817

File tree

7 files changed

+41
-14
lines changed

7 files changed

+41
-14
lines changed

src/TensorCrossInterpolation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import Base: ==, +
1212
import Base: isempty, iterate, getindex, lastindex, broadcastable
1313
import Base: length, size, sum
1414

15-
export crossinterpolate, crossinterpolate2, optfirstpivot
15+
export crossinterpolate1, crossinterpolate2, optfirstpivot
1616
export tensortrain
1717

1818
include("util.jl")

src/tensorci1.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
mutable struct TensorCI1{ValueType} <: AbstractTensorTrain{ValueType}
33
4-
Type that represents tensor cross interpolations created using the TCI1 algorithm. Users may want to create these using [`crossinterpolate`](@ref) rather than calling a constructor directly.
4+
Type that represents tensor cross interpolations created using the TCI1 algorithm. Users may want to create these using [`crossinterpolate1`](@ref) rather than calling a constructor directly.
55
"""
66
mutable struct TensorCI1{ValueType} <: AbstractTensorTrain{ValueType}
77
Iset::Vector{IndexSet{MultiIndex}}
@@ -469,7 +469,7 @@ end
469469

470470

471471
@doc raw"""
472-
function crossinterpolate(
472+
function crossinterpolate1(
473473
::Type{ValueType},
474474
f,
475475
localdims::Union{Vector{Int},NTuple{N,Int}},
@@ -504,7 +504,7 @@ Notes:
504504
505505
See also: [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate2`](@ref)
506506
"""
507-
function crossinterpolate(
507+
function crossinterpolate1(
508508
::Type{ValueType},
509509
f,
510510
localdims::Union{Vector{Int},NTuple{N,Int}},
@@ -551,3 +551,30 @@ function crossinterpolate(
551551
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
552552
return tci, ranks, errors ./ errornormalization
553553
end
554+
555+
@doc raw"""
556+
function crossinterpolate(
557+
::Type{ValueType},
558+
f,
559+
localdims::Union{Vector{Int},NTuple{N,Int}},
560+
firstpivot::MultiIndex=ones(Int, length(localdims));
561+
tolerance::Float64=1e-8,
562+
maxiter::Int=200,
563+
sweepstrategy::Symbol=:backandforth,
564+
pivottolerance::Float64=1e-12,
565+
verbosity::Int=0,
566+
additionalpivots::Vector{MultiIndex}=MultiIndex[],
567+
normalizeerror::Bool=true
568+
) where {ValueType, N}
569+
570+
Deprecated, and only included for backward compatibility. Please use [`crossinterpolate1`](@ref) instead.
571+
"""
572+
function crossinterpolate(
573+
::Type{ValueType},
574+
f,
575+
localdims::Union{Vector{Int},NTuple{N,Int}},
576+
firstpivot::MultiIndex=ones(Int, length(localdims));
577+
kwargs...
578+
) where {ValueType, N}
579+
return crossinterpolate1(ValueType, f, localdims, firstpivot; kwargs...)
580+
end

src/tensorci2.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ Notes:
632632
- By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate.
633633
634634
635-
See also: [`crossinterpolate2`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate`](@ref)
635+
See also: [`crossinterpolate2`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate1`](@ref)
636636
"""
637637
function optimize!(
638638
tci::TensorCI2{ValueType},
@@ -871,7 +871,7 @@ Notes:
871871
- By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate.
872872
873873
874-
See also: [`optimize!`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate`](@ref)
874+
See also: [`optimize!`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate1`](@ref)
875875
"""
876876
function crossinterpolate2(
877877
::Type{ValueType},

src/util.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ Optimize the first pivot for a tensor cross interpolation.
6969
7070
Arguments:
7171
- `f` is function to be interpolated.
72-
- `localdims::Union{Vector{Int},NTuple{N,Int}}` determines the local dimensions of the function parameters (see [`crossinterpolate`](@ref)).
72+
- `localdims::Union{Vector{Int},NTuple{N,Int}}` determines the local dimensions of the function parameters (see [`crossinterpolate1`](@ref)).
7373
- `fistpivot::MultiIndex=ones(Int, length(localdims))` is the starting point for the optimization. It is advantageous to choose it close to a global maximum of the function.
7474
- `maxsweep` is the maximum number of optimization sweeps. Default: `1000`.
7575
76-
See also: [`crossinterpolate`](@ref)
76+
See also: [`crossinterpolate1`](@ref)
7777
"""
7878
function optfirstpivot(
7979
f,

test/test_conversion.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test
22
import TensorCrossInterpolation: TensorCI1, TensorCI2, sitedims, linkdims, rank,
3-
addglobalpivot!, crossinterpolate, optimize!, MatrixACA, rrlu,
3+
addglobalpivot!, crossinterpolate1, optimize!, MatrixACA, rrlu,
44
nrows, ncols, evaluate, left, right
55

66
@testset "Conversion between rrLU and ACA" begin
@@ -40,7 +40,7 @@ end
4040
@test linkdims(tci2) == linkdims(tci1)
4141

4242
f(v) = (1.0 + 2.0im) ./ (sum(v .^ 2) + 1)
43-
tci1, ranks, errors = crossinterpolate(
43+
tci1, ranks, errors = crossinterpolate1(
4444
ComplexF64,
4545
f,
4646
fill(d, n),

test/test_tensorci1.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ import TensorCrossInterpolation: IndexSet, MultiIndex, CachedFunction, TensorCI1
9898
@test rank(tci) == iter
9999
end
100100

101-
tci2, ranks, errors = crossinterpolate(
101+
tci2, ranks, errors = crossinterpolate1(
102102
ValueType,
103103
f,
104104
fill(10, n),
@@ -111,7 +111,7 @@ import TensorCrossInterpolation: IndexSet, MultiIndex, CachedFunction, TensorCI1
111111
@test linkdims(tci) == linkdims(tci2)
112112
@test rank(tci) == rank(tci2)
113113

114-
tci3, ranks, errors = crossinterpolate(
114+
tci3, ranks, errors = crossinterpolate1(
115115
ValueType,
116116
f,
117117
fill(10, n),
@@ -124,7 +124,7 @@ import TensorCrossInterpolation: IndexSet, MultiIndex, CachedFunction, TensorCI1
124124
@test all(linkdims(tci3) .<= 200)
125125
@test rank(tci3) <= 200
126126

127-
tci4, ranks, errors = crossinterpolate(
127+
tci4, ranks, errors = crossinterpolate1(
128128
ValueType,
129129
f,
130130
fill(10, n),

test/test_tensortrain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Optim
77
g(v) = 1 / (sum(v .^ 2) + 1)
88
localdims = (6, 6, 6, 6)
99
tolerance = 1e-8
10-
tci, ranks, errors = TCI.crossinterpolate(Float64, g, localdims; tolerance=tolerance)
10+
tci, ranks, errors = TCI.crossinterpolate1(Float64, g, localdims; tolerance=tolerance)
1111
tt = TCI.TensorTrain(tci)
1212
@test rank(tci) == rank(tt)
1313
@test TCI.linkdims(tci) == TCI.linkdims(tt)

0 commit comments

Comments
 (0)