Skip to content

Commit ccf4bc9

Browse files
committed
renamed crossinterpolate->crossinterpolate1
also introduced deprecated crossinterpolate() for eventual removal.
1 parent 72f0307 commit ccf4bc9

File tree

5 files changed

+37
-10
lines changed

5 files changed

+37
-10
lines changed

src/TensorCrossInterpolation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import Base: ==, +
1212
import Base: isempty, iterate, getindex, lastindex, broadcastable
1313
import Base: length, size, sum
1414

15-
export crossinterpolate, crossinterpolate2, optfirstpivot
15+
export crossinterpolate1, crossinterpolate2, optfirstpivot
1616
export tensortrain
1717

1818
include("util.jl")

src/tensorci1.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
mutable struct TensorCI1{ValueType} <: AbstractTensorTrain{ValueType}
33
4-
Type that represents tensor cross interpolations created using the TCI1 algorithm. Users may want to create these using [`crossinterpolate`](@ref) rather than calling a constructor directly.
4+
Type that represents tensor cross interpolations created using the TCI1 algorithm. Users may want to create these using [`crossinterpolate1`](@ref) rather than calling a constructor directly.
55
"""
66
mutable struct TensorCI1{ValueType} <: AbstractTensorTrain{ValueType}
77
Iset::Vector{IndexSet{MultiIndex}}
@@ -469,7 +469,7 @@ end
469469

470470

471471
@doc raw"""
472-
function crossinterpolate(
472+
function crossinterpolate1(
473473
::Type{ValueType},
474474
f,
475475
localdims::Union{Vector{Int},NTuple{N,Int}},
@@ -504,7 +504,7 @@ Notes:
504504
505505
See also: [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate2`](@ref)
506506
"""
507-
function crossinterpolate(
507+
function crossinterpolate1(
508508
::Type{ValueType},
509509
f,
510510
localdims::Union{Vector{Int},NTuple{N,Int}},
@@ -551,3 +551,30 @@ function crossinterpolate(
551551
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
552552
return tci, ranks, errors ./ errornormalization
553553
end
554+
555+
@doc raw"""
556+
function crossinterpolate(
557+
::Type{ValueType},
558+
f,
559+
localdims::Union{Vector{Int},NTuple{N,Int}},
560+
firstpivot::MultiIndex=ones(Int, length(localdims));
561+
tolerance::Float64=1e-8,
562+
maxiter::Int=200,
563+
sweepstrategy::Symbol=:backandforth,
564+
pivottolerance::Float64=1e-12,
565+
verbosity::Int=0,
566+
additionalpivots::Vector{MultiIndex}=MultiIndex[],
567+
normalizeerror::Bool=true
568+
) where {ValueType, N}
569+
570+
Deprecated, and only included for backward compatibility. Please use [`crossinterpolate1`](@ref) instead.
571+
"""
572+
function crossinterpolate(
573+
::Type{ValueType},
574+
f,
575+
localdims::Union{Vector{Int},NTuple{N,Int}},
576+
firstpivot::MultiIndex=ones(Int, length(localdims));
577+
kwargs...
578+
) where {ValueType, N}
579+
return crossinterpolate1(ValueType, f, localdims, firstpivot; kwargs...)
580+
end

test/test_conversion.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test
22
import TensorCrossInterpolation: TensorCI1, TensorCI2, sitedims, linkdims, rank,
3-
addglobalpivot!, crossinterpolate, optimize!, MatrixACA, rrlu,
3+
addglobalpivot!, crossinterpolate1, optimize!, MatrixACA, rrlu,
44
nrows, ncols, evaluate, left, right
55

66
@testset "Conversion between rrLU and ACA" begin
@@ -40,7 +40,7 @@ end
4040
@test linkdims(tci2) == linkdims(tci1)
4141

4242
f(v) = (1.0 + 2.0im) ./ (sum(v .^ 2) + 1)
43-
tci1, ranks, errors = crossinterpolate(
43+
tci1, ranks, errors = crossinterpolate1(
4444
ComplexF64,
4545
f,
4646
fill(d, n),

test/test_tensorci1.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ import TensorCrossInterpolation: IndexSet, MultiIndex, CachedFunction, TensorCI1
9898
@test rank(tci) == iter
9999
end
100100

101-
tci2, ranks, errors = crossinterpolate(
101+
tci2, ranks, errors = crossinterpolate1(
102102
ValueType,
103103
f,
104104
fill(10, n),
@@ -111,7 +111,7 @@ import TensorCrossInterpolation: IndexSet, MultiIndex, CachedFunction, TensorCI1
111111
@test linkdims(tci) == linkdims(tci2)
112112
@test rank(tci) == rank(tci2)
113113

114-
tci3, ranks, errors = crossinterpolate(
114+
tci3, ranks, errors = crossinterpolate1(
115115
ValueType,
116116
f,
117117
fill(10, n),
@@ -124,7 +124,7 @@ import TensorCrossInterpolation: IndexSet, MultiIndex, CachedFunction, TensorCI1
124124
@test all(linkdims(tci3) .<= 200)
125125
@test rank(tci3) <= 200
126126

127-
tci4, ranks, errors = crossinterpolate(
127+
tci4, ranks, errors = crossinterpolate1(
128128
ValueType,
129129
f,
130130
fill(10, n),

test/test_tensortrain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Optim
77
g(v) = 1 / (sum(v .^ 2) + 1)
88
localdims = (6, 6, 6, 6)
99
tolerance = 1e-8
10-
tci, ranks, errors = TCI.crossinterpolate(Float64, g, localdims; tolerance=tolerance)
10+
tci, ranks, errors = TCI.crossinterpolate1(Float64, g, localdims; tolerance=tolerance)
1111
tt = TCI.TensorTrain(tci)
1212
@test rank(tci) == rank(tt)
1313
@test TCI.linkdims(tci) == TCI.linkdims(tt)

0 commit comments

Comments
 (0)