|
| 1 | +using Test |
| 2 | +import TensorCrossInterpolation as TCI |
| 3 | +import TensorCrossInterpolation: rank, linkdims, TensorCI2, updatepivots!, addglobalpivots1sitesweep!, MultiIndex, evaluate, SweepStrategies, crossinterpolate2, pivoterror, tensortrain |
| 4 | +import Random |
| 5 | +import QuanticsGrids as QD |
| 6 | + |
| 7 | +@testset "TensorCI2" begin |
| 8 | + #== |
| 9 | + @testset "kronecker util function" begin |
| 10 | + multiset = [collect(1:5) for _ in 1:5] |
| 11 | + localdim = 4 |
| 12 | + localset = collect(1:localdim) |
| 13 | + |
| 14 | + c = TCI.kronecker(multiset, localdim) |
| 15 | + for (i, ci) in enumerate(c) |
| 16 | + @test ci[1:5] == collect(1:5) |
| 17 | + @test ci[6] in localset |
| 18 | + end |
| 19 | + |
| 20 | + d = TCI.kronecker(localdim, multiset) |
| 21 | + for (i, di) in enumerate(d) |
| 22 | + @test di[1] in localset |
| 23 | + @test di[2:6] == collect(1:5) |
| 24 | + end |
| 25 | + end |
| 26 | + |
| 27 | + @testset "trivial MPS(exp): pivotsearch=$pivotsearch" for pivotsearch in [:full, :rook] |
| 28 | + # f(x) = exp(-x) |
| 29 | + Random.seed!(1240) |
| 30 | + R = 8 |
| 31 | + abstol = 1e-4 |
| 32 | + |
| 33 | + grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,)) |
| 34 | + |
| 35 | + #index_to_x(i) = (i - 1) / 2^R # x ∈ [0, 1) |
| 36 | + fx(x) = exp(-x) |
| 37 | + f(bitlist::MultiIndex) = fx(QD.quantics_to_origcoord(grid, bitlist)[1]) |
| 38 | + |
| 39 | + localdims = fill(2, R) |
| 40 | + firstpivots = [ones(Int, R), vcat(1, fill(2, R - 1))] |
| 41 | + tci, ranks, errors = crossinterpolate2( |
| 42 | + Float64, |
| 43 | + f, |
| 44 | + localdims, |
| 45 | + firstpivots; |
| 46 | + tolerance=abstol, |
| 47 | + maxbonddim=1, |
| 48 | + maxiter=2, |
| 49 | + loginterval=1, |
| 50 | + verbosity=0, |
| 51 | + normalizeerror=false |
| 52 | + ) |
| 53 | + |
| 54 | + @test all(TCI.linkdims(tci) .== 1) |
| 55 | + |
| 56 | + for x in [0.1, 0.3, 0.6, 0.9] |
| 57 | + indexset = QD.origcoord_to_quantics( |
| 58 | + grid, (x,) |
| 59 | + ) |
| 60 | + @test abs(TCI.evaluate(tci, indexset) - f(indexset)) < abstol |
| 61 | + end |
| 62 | + |
| 63 | + end |
| 64 | + |
| 65 | + @testset "trivial MPS" begin |
| 66 | + n = 5 |
| 67 | + f(v) = sum(v) * 0.5 |
| 68 | + |
| 69 | + tci = TensorCI2{Float64}(fill(2, n)) |
| 70 | + @test length(tci) == n |
| 71 | + @test rank(tci) == 0 |
| 72 | + @test linkdims(tci) == fill(0, n - 1) |
| 73 | + for i in 1:n |
| 74 | + @test isempty(tci.Iset[i]) |
| 75 | + @test isempty(tci.Jset[i]) |
| 76 | + end |
| 77 | + |
| 78 | + tci = TCI.TensorCI2{Float64}(f, fill(2, n), [fill(1, n)]) |
| 79 | + @test length(tci) == n |
| 80 | + @test rank(tci) == 1 |
| 81 | + @test linkdims(tci) == fill(1, n - 1) |
| 82 | + end |
| 83 | + |
| 84 | + @testset "Lorentz MPS with ValueType=$(typeof(coeff)), pivotsearch=$pivotsearch" for coeff in [1.0, 0.5 - 1.0im], pivotsearch in [:full, :rook] |
| 85 | + n = 5 |
| 86 | + f(v) = coeff ./ (sum(v .^ 2) + 1) |
| 87 | + |
| 88 | + ValueType = typeof(coeff) |
| 89 | + |
| 90 | + tci = TensorCI2{ValueType}(f, fill(10, n)) |
| 91 | + |
| 92 | + @test linkdims(tci) == ones(n - 1) |
| 93 | + @test rank(tci) == 1 |
| 94 | + @test length(tci.Iset[1]) == 1 |
| 95 | + @test length(tci.Jset[end]) == 1 |
| 96 | + |
| 97 | + for p in 1:n-1 |
| 98 | + updatepivots!(tci, p, f, true; reltol=1e-8, maxbonddim=2, pivotsearch) |
| 99 | + end |
| 100 | + @test linkdims(tci) == fill(2, n - 1) |
| 101 | + @test rank(tci) == 2 |
| 102 | + @test length(tci.Iset[1]) == 1 |
| 103 | + @test length(tci.Jset[end]) == 1 |
| 104 | + |
| 105 | + globalpivot = [2, 9, 10, 5, 7] |
| 106 | + addglobalpivots1sitesweep!(tci, f, [globalpivot], reltol=1e-12) |
| 107 | + @test linkdims(tci) == fill(3, n - 1) |
| 108 | + @test rank(tci) == 3 |
| 109 | + @test length(tci.Iset[1]) == 1 |
| 110 | + @test length(tci.Jset[end]) == 1 |
| 111 | + |
| 112 | + for iter in 4:20 |
| 113 | + for p in 1:n-1 |
| 114 | + updatepivots!(tci, p, f, true; reltol=1e-8, pivotsearch) |
| 115 | + end |
| 116 | + end |
| 117 | + |
| 118 | + tci2, ranks, errors = crossinterpolate2( |
| 119 | + ValueType, |
| 120 | + f, |
| 121 | + fill(10, n), |
| 122 | + [ones(Int, n)]; |
| 123 | + tolerance=1e-8, |
| 124 | + pivottolerance=1e-8, |
| 125 | + maxiter=8, |
| 126 | + sweepstrategy=SweepStrategies.forward, |
| 127 | + pivotsearch=pivotsearch |
| 128 | + ) |
| 129 | + |
| 130 | + #@test linkdims(tci) == linkdims(tci2) Too strict |
| 131 | + @test rank(tci) == rank(tci2) |
| 132 | + |
| 133 | + tci3, ranks, errors = crossinterpolate2( |
| 134 | + ValueType, |
| 135 | + f, |
| 136 | + fill(10, n), |
| 137 | + [ones(Int, n)]; |
| 138 | + tolerance=1e-12, |
| 139 | + maxiter=200, |
| 140 | + pivotsearch |
| 141 | + ) |
| 142 | + |
| 143 | + @test pivoterror(tci3) <= 2e-12 |
| 144 | + @test all(linkdims(tci3) .<= 200) |
| 145 | + @test rank(tci3) <= 200 |
| 146 | + |
| 147 | + initialpivots = [ |
| 148 | + [1, 1, 1, 1, 1], |
| 149 | + [10, 8, 10, 4, 4], |
| 150 | + [5, 4, 8, 9, 3], |
| 151 | + [7, 7, 10, 5, 9], |
| 152 | + [7, 7, 10, 5, 9] |
| 153 | + ] |
| 154 | + |
| 155 | + tci4, ranks, errors = crossinterpolate2( |
| 156 | + ValueType, |
| 157 | + f, |
| 158 | + fill(10, n), |
| 159 | + initialpivots; |
| 160 | + tolerance=1e-12, |
| 161 | + maxiter=200, |
| 162 | + pivotsearch |
| 163 | + ) |
| 164 | + |
| 165 | + @test pivoterror(tci4) <= 2e-12 |
| 166 | + @test all(linkdims(tci4) .<= 200) |
| 167 | + @test rank(tci4) <= 200 |
| 168 | + |
| 169 | + tt3 = tensortrain(tci3) |
| 170 | + |
| 171 | + for v in Iterators.product([1:3 for p in 1:n]...) |
| 172 | + value = evaluate(tci3, [i for i in v]) |
| 173 | + @test value ≈ prod([tt3[p][:, v[p], :] for p in eachindex(v)])[1] |
| 174 | + @test value ≈ f(v) |
| 175 | + end |
| 176 | + end |
| 177 | + ==# |
| 178 | + |
| 179 | + @testset "insert_global_pivots: pivotsearch=$pivotsearch" for pivotsearch in [:full], partialnesting in [false, true] |
| 180 | + Random.seed!(1234) |
| 181 | + |
| 182 | + R = 20 |
| 183 | + abstol = 1e-4 |
| 184 | + grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,)) |
| 185 | + |
| 186 | + rindex = [rand(1:2, R) for _ in 1:100] |
| 187 | + |
| 188 | + f(bitlist) = fx(QD.quantics_to_origcoord(grid, bitlist)[1]) |
| 189 | + rpoint = Float64[QD.quantics_to_origcoord(grid, r)[1] for r in rindex] |
| 190 | + |
| 191 | + function fx(x) |
| 192 | + res = exp(-10 * x) |
| 193 | + for r in rpoint |
| 194 | + res += abs(x - r) < 1e-5 ? 2 * abstol : 0.0 |
| 195 | + end |
| 196 | + res |
| 197 | + end |
| 198 | + |
| 199 | + localdims = fill(2, R) |
| 200 | + firstpivot = ones(Int, R) |
| 201 | + tci, ranks, errors = crossinterpolate2( |
| 202 | + Float64, |
| 203 | + f, |
| 204 | + localdims, |
| 205 | + [firstpivot]; |
| 206 | + tolerance=abstol, |
| 207 | + maxbonddim=1000, |
| 208 | + maxiter=20, |
| 209 | + loginterval=1, |
| 210 | + verbosity=0, |
| 211 | + normalizeerror=false, |
| 212 | + pivotsearch=pivotsearch, |
| 213 | + partialnesting=true |
| 214 | + ) |
| 215 | + #@show sum(abs.([TCI.evaluate(tci, r) - f(r) for r in rindex]) .> abstol) |
| 216 | + |
| 217 | + TCI.addglobalpivots2sitesweep!( |
| 218 | + tci, f, rindex, |
| 219 | + tolerance=abstol, |
| 220 | + normalizeerror=false, |
| 221 | + maxbonddim=1000, |
| 222 | + pivotsearch=pivotsearch, |
| 223 | + verbosity=1, |
| 224 | + partialnesting=partialnesting, |
| 225 | + ntry = (!partialnesting && pivotsearch == :full) ? 1 : 10 |
| 226 | + ) |
| 227 | + @test sum(abs.([TCI.evaluate(tci, r) - f(r) for r in rindex]) .> abstol) == 0 |
| 228 | + end |
| 229 | + |
| 230 | + #== |
| 231 | + @testset "globalsearch" begin |
| 232 | + Random.seed!(1234) |
| 233 | + |
| 234 | + n = 10 |
| 235 | + fx(x) = exp(-10 * x) * sin(2 * pi * 100 * x^1.1) # Nasty function |
| 236 | + f(bitlist) = fx(QD.quantics_to_origcoord(grid, bitlist)[1]) |
| 237 | + grid = QD.DiscretizedGrid{1}(n, (0.0,), (1.0,)) |
| 238 | + |
| 239 | + localdims = fill(2, n) |
| 240 | + |
| 241 | + # This checks only that the function runs without error |
| 242 | + tci, ranks, errors = crossinterpolate2( |
| 243 | + Float64, |
| 244 | + f, |
| 245 | + localdims, |
| 246 | + tolerance=1e-12, |
| 247 | + maxbonddim=100, |
| 248 | + maxiter=100, |
| 249 | + nsearchglobalpivot=10 |
| 250 | + ) |
| 251 | + |
| 252 | + @test errors[end] < 1e-10 |
| 253 | + end |
| 254 | + |
| 255 | + |
| 256 | + @testset "crossinterpolate2_ttcache" begin |
| 257 | + ValueType = Float64 |
| 258 | + |
| 259 | + N = 4 |
| 260 | + bonddims = [1, 2, 3, 2, 1] |
| 261 | + @assert length(bonddims) == N + 1 |
| 262 | + localdims = [2, 3, 3, 2] |
| 263 | + |
| 264 | + tt = TCI.TensorTrain{ValueType,3}([rand(bonddims[n], localdims[n], bonddims[n+1]) for n in 1:N]) |
| 265 | + ttc = TCI.TTCache(tt.T) |
| 266 | + |
| 267 | + tci2, ranks, errors = TCI.crossinterpolate2( |
| 268 | + ValueType, |
| 269 | + ttc, |
| 270 | + localdims; |
| 271 | + tolerance=1e-10, |
| 272 | + maxbonddim = 10 |
| 273 | + ) |
| 274 | + |
| 275 | + tt_reconst = TCI.TensorTrain(tci2) |
| 276 | + |
| 277 | + vals_reconst = [tt_reconst(collect(indices)) for indices in Iterators.product((1:d for d in localdims)...)] |
| 278 | + vals_ref = [tt(collect(indices)) for indices in Iterators.product((1:d for d in localdims)...)] |
| 279 | + |
| 280 | + @test vals_reconst ≈ vals_ref |
| 281 | + end |
| 282 | + ==# |
| 283 | +end |
0 commit comments