Skip to content

Commit ae8967a

Browse files
committed
Some clean up for global search, normalizerror now defaults to true in integrate()
1 parent a9aa449 commit ae8967a

File tree

3 files changed

+40
-77
lines changed

3 files changed

+40
-77
lines changed

src/globalsearch.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function estimatetrueerror(
3232

3333
ttcache = TTCache(tt)
3434

35-
pivoterror = [_floatingzone(ttcache, f, initp) for initp in initialpoints]
35+
pivoterror = [_floatingzone(ttcache, f; initp=initp) for initp in initialpoints]
3636

3737
p = sortperm([e for (_, e) in pivoterror], rev=true)
3838

@@ -41,14 +41,25 @@ end
4141

4242

4343
function _floatingzone(
44-
ttcache::TTCache{ValueType}, f, pivot
44+
ttcache::TTCache{ValueType}, f;
45+
earlystoptol::Float64 = typemax(Float64),
46+
nsweeps=typemax(Int), initp::Union{Nothing,MultiIndex} = nothing
4547
)::Tuple{MultiIndex,Float64} where {ValueType}
48+
nsweeps > 0 || error("nsweeps should be positive!")
49+
50+
localdims = first.(sitedims(ttcache))
51+
4652
n = length(ttcache)
4753

54+
if initp === nothing
55+
pivot = [rand(1:d) for d in localdims]
56+
else
57+
pivot = initp
58+
end
59+
4860
maxerror = abs(f(pivot) - ttcache(pivot))
49-
localdims = first.(sitedims(ttcache))
5061

51-
while true
62+
for isweep in 1:nsweeps
5263
prev_maxerror = maxerror
5364
for ipos in 1:n
5465
exactdata = filltensor(
@@ -72,10 +83,28 @@ function _floatingzone(
7283
maxerror = maximum(err)
7384
end
7485

75-
if maxerror == prev_maxerror
86+
if maxerror == prev_maxerror || maxerror > earlystoptol # early stop
7687
break
7788
end
7889
end
7990

8091
return pivot, maxerror
92+
end
93+
94+
95+
function fillsitetensors!(
96+
tci::TensorCI2{ValueType}, f) where {ValueType}
97+
for b in 1:length(tci)
98+
setsitetensor!(tci, f, b)
99+
end
100+
nothing
101+
end
102+
103+
104+
function _sanitycheck(tci::TensorCI2{ValueType})::Bool where {ValueType}
105+
for b in 1:length(tci)-1
106+
length(tci.Iset[b+1]) == length(tci.Jset[b]) || error("Pivot matrix at bond $(b) is not square!")
107+
end
108+
109+
return true
81110
end

src/integration.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ function integrate(
2323
a::Vector{ValueType},
2424
b::Vector{ValueType};
2525
GKorder::Int=15,
26-
normalizeerror=false,
2726
kwargs...
2827
) where {ValueType}
2928
if iseven(GKorder)
@@ -52,7 +51,6 @@ function integrate(
5251
F,
5352
localdims;
5453
nsearchglobalpivot=10,
55-
normalizeerror,
5654
kwargs...
5755
)
5856

src/tensorci2.jl

Lines changed: 6 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -704,10 +704,11 @@ function optimize!(
704704
sweepstrategy=sweepstrategy,
705705
fillsitetensors=true
706706
)
707-
if verbosity > 0 && length(globalpivots) > 0
708-
nrejections = length([p for p in globalpivots if abs(evaluate(tci, p) - f(p)) > abstol])
707+
if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0
708+
abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots]
709+
nrejections = length(abserr .> abstol)
709710
if nrejections > 0
710-
println(" Rejected $(nrejections) global pivots added in the previous iteration")
711+
println(" Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)")
711712
flush(stdout)
712713
end
713714
end
@@ -916,8 +917,9 @@ function searchglobalpivots(
916917
end
917918

918919
pivots = Dict{Float64,MultiIndex}()
920+
ttcache = TTCache(tci)
919921
for _ in 1:nsearch
920-
pivot, error = _floatingzone(tci, f, 10 * abstol)
922+
pivot, error = _floatingzone(ttcache, f; earlystoptol = 10 * abstol, nsweeps=100)
921923
if error > abstol
922924
pivots[error] = pivot
923925
end
@@ -941,69 +943,3 @@ function searchglobalpivots(
941943
return [p for (_,p) in pivots]
942944
end
943945

944-
945-
function _floatingzone(
946-
tci::TensorCI2{ValueType}, f, abstol;
947-
nsweeps=100
948-
)::Tuple{MultiIndex,Float64} where {ValueType}
949-
nsweeps > 0 || error("nsweeps should be positive!")
950-
951-
localdims = tci.localdims
952-
953-
n = length(tci)
954-
955-
ttcache = TTCache(tci)
956-
957-
pivot = [rand(1:d) for d in localdims]
958-
959-
maxerror = abs(f(pivot) - ttcache(pivot))
960-
961-
for isweep in 1:nsweeps
962-
prev_maxerror = maxerror
963-
for ipos in 1:n
964-
exactdata = filltensor(
965-
ValueType,
966-
f,
967-
tci.localdims,
968-
[pivot[1:ipos-1]],
969-
[pivot[ipos+1:end]],
970-
Val(1)
971-
)
972-
prediction = filltensor(
973-
ValueType,
974-
ttcache,
975-
tci.localdims,
976-
[pivot[1:ipos-1]],
977-
[pivot[ipos+1:end]],
978-
Val(1)
979-
)
980-
err = vec(abs.(exactdata .- prediction))
981-
pivot[ipos] = argmax(err)
982-
maxerror = maximum(err)
983-
end
984-
985-
if maxerror == prev_maxerror || maxerror > abstol # early stop
986-
break
987-
end
988-
end
989-
990-
return pivot, maxerror
991-
end
992-
993-
994-
function fillsitetensors!(
995-
tci::TensorCI2{ValueType}, f) where {ValueType}
996-
for b in 1:length(tci)
997-
setsitetensor!(tci, f, b)
998-
end
999-
nothing
1000-
end
1001-
1002-
1003-
function _sanitycheck(tci::TensorCI2{ValueType})::Bool where {ValueType}
1004-
for b in 1:length(tci)-1
1005-
length(tci.Iset[b+1]) == length(tci.Jset[b]) || error("Pivot matrix at bond $(b) is not square!")
1006-
end
1007-
1008-
return true
1009-
end

0 commit comments

Comments
 (0)