Skip to content

Commit 8c462dc

Browse files
authored
Merge pull request #44 from tensor4all/35-deprecate-pivottolerance-in-TCI2
deprecated pivottolerance option in TCI2
2 parents 158b599 + d8a789b commit 8c462dc

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

src/tensorci2.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,6 @@ end
597597
tci::TensorCI2{ValueType},
598598
f;
599599
tolerance::Float64=1e-8,
600-
pivottolerance::Float64=tolerance,
601600
maxbonddim::Int=typemax(Int),
602601
maxiter::Int=200,
603602
sweepstrategy::Symbol=:backandforth,
@@ -644,8 +643,8 @@ See also: [`crossinterpolate2`](@ref), [`optfirstpivot`](@ref), [`CachedFunction
644643
function optimize!(
645644
tci::TensorCI2{ValueType},
646645
f;
647-
tolerance::Float64=1e-8,
648-
pivottolerance::Float64=tolerance,
646+
tolerance::Union{Float64, Nothing}=nothing,
647+
pivottolerance::Union{Float64, Nothing}=nothing,
649648
maxbonddim::Int=typemax(Int),
650649
maxiter::Int=20,
651650
sweepstrategy::Symbol=:backandforth,
@@ -663,6 +662,7 @@ function optimize!(
663662
errors = Float64[]
664663
ranks = Int[]
665664
nglobalpivots = Int[]
665+
local tol::Float64
666666

667667
if checkbatchevaluatable && !(f isa BatchEvaluator)
668668
error("Function `f` is not batch evaluatable")
@@ -675,9 +675,23 @@ function optimize!(
675675
error("nsearchglobalpivot < maxnglobalpivot!")
676676
end
677677

678+
# Deprecate the pivottolerance option
679+
if !isnothing(pivottolerance)
680+
if !isnothing(tolerance) && (tolerance != pivottolerance)
681+
throw(ArgumentError("Got different values for pivottolerance and tolerance in optimize!(TCI2). For TCI2, both of these options have the same meaning. Please assign only `tolerance`."))
682+
else
683+
@warn "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future."
684+
tol = pivottolerance
685+
end
686+
elseif !isnothing(tolerance)
687+
tol = tolerance
688+
else # pivottolerance == tolerance == nothing, therefore set tol to default value
689+
tol = 1e-8
690+
end
691+
678692
tstart = time_ns()
679693

680-
if maxbonddim >= typemax(Int) && tolerance <= 0
694+
if maxbonddim >= typemax(Int) && tol <= 0
681695
throw(ArgumentError(
682696
"Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!"
683697
))
@@ -686,7 +700,7 @@ function optimize!(
686700
globalpivots = MultiIndex[]
687701
for iter in 1:maxiter
688702
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
689-
abstol = pivottolerance * errornormalization;
703+
abstol = tol * errornormalization;
690704

691705
if verbosity > 1
692706
println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep")
@@ -753,7 +767,7 @@ function optimize!(
753767
# or the bond dimension exceeds maxbonddim
754768
# (2) Compute site tensors
755769
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
756-
abstol = pivottolerance * errornormalization;
770+
abstol = tol * errornormalization;
757771
sweep1site!(
758772
tci,
759773
f,
@@ -792,7 +806,7 @@ function sweep2site!(
792806
extraIset = tci.Iset_history[end]
793807
extraJset = tci.Jset_history[end]
794808
end
795-
809+
796810
push!(tci.Iset_history, deepcopy(tci.Iset))
797811
push!(tci.Jset_history, deepcopy(tci.Jset))
798812

@@ -940,4 +954,3 @@ function searchglobalpivots(
940954

941955
return [p for (_,p) in pivots]
942956
end
943-

test/test_tensorci2.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test
22
import TensorCrossInterpolation as TCI
3-
import TensorCrossInterpolation: rank, linkdims, TensorCI2, updatepivots!, addglobalpivots1sitesweep!, MultiIndex, evaluate, crossinterpolate2, pivoterror, tensortrain
3+
import TensorCrossInterpolation: rank, linkdims, TensorCI2, updatepivots!, addglobalpivots1sitesweep!, MultiIndex, evaluate, crossinterpolate2, pivoterror, tensortrain, optimize!
44
import Random
55
import QuanticsGrids as QD
66

@@ -166,6 +166,19 @@ import QuanticsGrids as QD
166166
@test linkdims(tci) == fill(1, n - 1)
167167
end
168168

169+
@testset "TCI2 errors and warnings" begin
170+
n = 5
171+
f(v) = 1.0 ./ (sum(v .^ 2) + 1)
172+
173+
@test_throws ArgumentError crossinterpolate2(Float64, f, fill(2, n); tolerance=1e-9, pivottolerance=1e-2)
174+
@test_throws ArgumentError crossinterpolate2(Float64, f, fill(2, n); tolerance=0.0)
175+
176+
tci, = crossinterpolate2(Float64, f, fill(2, n); tolerance=0.1)
177+
@test_throws ArgumentError optimize!(tci, f; pivottolerance = 0.1, tolerance = 0.01)
178+
@test_throws ArgumentError optimize!(tci, f; tolerance = 0.0)
179+
@test_logs (:warn, "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future.") optimize!(tci, f; pivottolerance = 0.1)
180+
end
181+
169182
@testset "Lorentz MPS with ValueType=$(typeof(coeff)), pivotsearch=$pivotsearch" for coeff in [1.0, 0.5 - 1.0im], pivotsearch in [:full, :rook]
170183
n = 5
171184
f(v) = coeff ./ (sum(v .^ 2) + 1)
@@ -206,7 +219,6 @@ import QuanticsGrids as QD
206219
fill(10, n),
207220
[ones(Int, n)];
208221
tolerance=1e-8,
209-
pivottolerance=1e-8,
210222
maxiter=8,
211223
sweepstrategy=:forward,
212224
pivotsearch=pivotsearch

0 commit comments

Comments
 (0)