Skip to content

Commit 4be3b5e

Browse files
committed
WIP
1 parent 85be933 commit 4be3b5e

File tree

2 files changed

+371
-0
lines changed

2 files changed

+371
-0
lines changed

test/matrixmul_tests.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
using Test
2+
using LinearAlgebra
3+
import TensorCrossInterpolation as TCI
4+
import TCIAlgorithms: MatrixProduct
5+
import TCIAlgorithms as TCIA
6+
using TCIITensorConversion
7+
8+
using ITensors
9+
10+
#==
11+
==#
12+
13+
function _tomat(tto::TCI.TensorTrain{T,4}) where {T}
14+
sitedims = TCI.sitedims(tto)
15+
localdims1 = [s[1] for s in sitedims]
16+
localdims2 = [s[2] for s in sitedims]
17+
mat = Matrix{T}(undef, prod(localdims1), prod(localdims2))
18+
for (i, inds1) in enumerate(CartesianIndices(Tuple(localdims1)))
19+
for (j, inds2) in enumerate(CartesianIndices(Tuple(localdims2)))
20+
mat[i, j] = TCI.evaluate(tto, collect(zip(Tuple(inds1), Tuple(inds2))))
21+
end
22+
end
23+
return mat
24+
end
25+
26+
@testset "MPO-MPO naive contraction" begin
27+
N = 4
28+
bonddims_a = [1, 2, 3, 2, 1]
29+
bonddims_b = [1, 2, 3, 2, 1]
30+
localdims1 = [2, 2, 2, 2]
31+
localdims2 = [3, 3, 3, 3]
32+
localdims3 = [2, 2, 2, 2]
33+
34+
a = TCI.TensorTrain{ComplexF64,4}([
35+
rand(ComplexF64, bonddims_a[n], localdims1[n], localdims2[n], bonddims_a[n+1])
36+
for n = 1:N
37+
])
38+
b = TCI.TensorTrain{ComplexF64,4}([
39+
rand(ComplexF64, bonddims_b[n], localdims2[n], localdims3[n], bonddims_b[n+1])
40+
for n = 1:N
41+
])
42+
43+
ab = TCIA.naivecontract(a, b)
44+
45+
sites1 = Index.(localdims1, "1")
46+
sites2 = Index.(localdims2, "2")
47+
sites3 = Index.(localdims3, "3")
48+
49+
#amps = MPO(a, sites = collect(zip(sites1, sites2)))
50+
#bmps = MPO(b, sites = collect(zip(sites2, sites3)))
51+
#abmps = amps * bmps
52+
53+
@test _tomat(ab) _tomat(a) * _tomat(b)
54+
55+
#for inds1 in CartesianIndices(Tuple(localdims1))
56+
#for inds3 in CartesianIndices(Tuple(localdims3))
57+
#refvalue = evaluate_mps(
58+
#abmps,
59+
#collect(zip(sites1, Tuple(inds1))),
60+
#collect(zip(sites3, Tuple(inds3))),
61+
#)
62+
#inds = collect(zip(Tuple(inds1), Tuple(inds3)))
63+
#@test ab(inds) ≈ refvalue
64+
#end
65+
#end
66+
end
67+
68+
@testset "MPO-MPO contraction" for f in [x -> x, x -> 2 * x]
69+
N = 4
70+
bonddims_a = [1, 2, 3, 2, 1]
71+
bonddims_b = [1, 2, 3, 2, 1]
72+
localdims1 = [2, 2, 2, 2]
73+
localdims2 = [3, 3, 3, 3]
74+
localdims3 = [2, 2, 2, 2]
75+
76+
a = TCI.TensorTrain{ComplexF64,4}([
77+
rand(ComplexF64, bonddims_a[n], localdims1[n], localdims2[n], bonddims_a[n+1])
78+
for n = 1:N
79+
])
80+
b = TCI.TensorTrain{ComplexF64,4}([
81+
rand(ComplexF64, bonddims_b[n], localdims2[n], localdims3[n], bonddims_b[n+1])
82+
for n = 1:N
83+
])
84+
85+
ab = TCIA.contract(a, b; f = f)
86+
@test TCI.sitedims(ab) == [[localdims1[i], localdims3[i]] for i = 1:N]
87+
@test _tomat(ab) f.(_tomat(a) * _tomat(b))
88+
end

test/test_tensorci2.jl.bak

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
using Test
2+
import TensorCrossInterpolation as TCI
3+
import TensorCrossInterpolation: rank, linkdims, TensorCI2, updatepivots!, addglobalpivots1sitesweep!, MultiIndex, evaluate, SweepStrategies, crossinterpolate2, pivoterror, tensortrain
4+
import Random
5+
import QuanticsGrids as QD
6+
7+
@testset "TensorCI2" begin
8+
#==
9+
@testset "kronecker util function" begin
10+
multiset = [collect(1:5) for _ in 1:5]
11+
localdim = 4
12+
localset = collect(1:localdim)
13+
14+
c = TCI.kronecker(multiset, localdim)
15+
for (i, ci) in enumerate(c)
16+
@test ci[1:5] == collect(1:5)
17+
@test ci[6] in localset
18+
end
19+
20+
d = TCI.kronecker(localdim, multiset)
21+
for (i, di) in enumerate(d)
22+
@test di[1] in localset
23+
@test di[2:6] == collect(1:5)
24+
end
25+
end
26+
27+
@testset "trivial MPS(exp): pivotsearch=$pivotsearch" for pivotsearch in [:full, :rook]
28+
# f(x) = exp(-x)
29+
Random.seed!(1240)
30+
R = 8
31+
abstol = 1e-4
32+
33+
grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,))
34+
35+
#index_to_x(i) = (i - 1) / 2^R # x ∈ [0, 1)
36+
fx(x) = exp(-x)
37+
f(bitlist::MultiIndex) = fx(QD.quantics_to_origcoord(grid, bitlist)[1])
38+
39+
localdims = fill(2, R)
40+
firstpivots = [ones(Int, R), vcat(1, fill(2, R - 1))]
41+
tci, ranks, errors = crossinterpolate2(
42+
Float64,
43+
f,
44+
localdims,
45+
firstpivots;
46+
tolerance=abstol,
47+
maxbonddim=1,
48+
maxiter=2,
49+
loginterval=1,
50+
verbosity=0,
51+
normalizeerror=false
52+
)
53+
54+
@test all(TCI.linkdims(tci) .== 1)
55+
56+
for x in [0.1, 0.3, 0.6, 0.9]
57+
indexset = QD.origcoord_to_quantics(
58+
grid, (x,)
59+
)
60+
@test abs(TCI.evaluate(tci, indexset) - f(indexset)) < abstol
61+
end
62+
63+
end
64+
65+
@testset "trivial MPS" begin
66+
n = 5
67+
f(v) = sum(v) * 0.5
68+
69+
tci = TensorCI2{Float64}(fill(2, n))
70+
@test length(tci) == n
71+
@test rank(tci) == 0
72+
@test linkdims(tci) == fill(0, n - 1)
73+
for i in 1:n
74+
@test isempty(tci.Iset[i])
75+
@test isempty(tci.Jset[i])
76+
end
77+
78+
tci = TCI.TensorCI2{Float64}(f, fill(2, n), [fill(1, n)])
79+
@test length(tci) == n
80+
@test rank(tci) == 1
81+
@test linkdims(tci) == fill(1, n - 1)
82+
end
83+
84+
@testset "Lorentz MPS with ValueType=$(typeof(coeff)), pivotsearch=$pivotsearch" for coeff in [1.0, 0.5 - 1.0im], pivotsearch in [:full, :rook]
85+
n = 5
86+
f(v) = coeff ./ (sum(v .^ 2) + 1)
87+
88+
ValueType = typeof(coeff)
89+
90+
tci = TensorCI2{ValueType}(f, fill(10, n))
91+
92+
@test linkdims(tci) == ones(n - 1)
93+
@test rank(tci) == 1
94+
@test length(tci.Iset[1]) == 1
95+
@test length(tci.Jset[end]) == 1
96+
97+
for p in 1:n-1
98+
updatepivots!(tci, p, f, true; reltol=1e-8, maxbonddim=2, pivotsearch)
99+
end
100+
@test linkdims(tci) == fill(2, n - 1)
101+
@test rank(tci) == 2
102+
@test length(tci.Iset[1]) == 1
103+
@test length(tci.Jset[end]) == 1
104+
105+
globalpivot = [2, 9, 10, 5, 7]
106+
addglobalpivots1sitesweep!(tci, f, [globalpivot], reltol=1e-12)
107+
@test linkdims(tci) == fill(3, n - 1)
108+
@test rank(tci) == 3
109+
@test length(tci.Iset[1]) == 1
110+
@test length(tci.Jset[end]) == 1
111+
112+
for iter in 4:20
113+
for p in 1:n-1
114+
updatepivots!(tci, p, f, true; reltol=1e-8, pivotsearch)
115+
end
116+
end
117+
118+
tci2, ranks, errors = crossinterpolate2(
119+
ValueType,
120+
f,
121+
fill(10, n),
122+
[ones(Int, n)];
123+
tolerance=1e-8,
124+
pivottolerance=1e-8,
125+
maxiter=8,
126+
sweepstrategy=SweepStrategies.forward,
127+
pivotsearch=pivotsearch
128+
)
129+
130+
#@test linkdims(tci) == linkdims(tci2) Too strict
131+
@test rank(tci) == rank(tci2)
132+
133+
tci3, ranks, errors = crossinterpolate2(
134+
ValueType,
135+
f,
136+
fill(10, n),
137+
[ones(Int, n)];
138+
tolerance=1e-12,
139+
maxiter=200,
140+
pivotsearch
141+
)
142+
143+
@test pivoterror(tci3) <= 2e-12
144+
@test all(linkdims(tci3) .<= 200)
145+
@test rank(tci3) <= 200
146+
147+
initialpivots = [
148+
[1, 1, 1, 1, 1],
149+
[10, 8, 10, 4, 4],
150+
[5, 4, 8, 9, 3],
151+
[7, 7, 10, 5, 9],
152+
[7, 7, 10, 5, 9]
153+
]
154+
155+
tci4, ranks, errors = crossinterpolate2(
156+
ValueType,
157+
f,
158+
fill(10, n),
159+
initialpivots;
160+
tolerance=1e-12,
161+
maxiter=200,
162+
pivotsearch
163+
)
164+
165+
@test pivoterror(tci4) <= 2e-12
166+
@test all(linkdims(tci4) .<= 200)
167+
@test rank(tci4) <= 200
168+
169+
tt3 = tensortrain(tci3)
170+
171+
for v in Iterators.product([1:3 for p in 1:n]...)
172+
value = evaluate(tci3, [i for i in v])
173+
@test value ≈ prod([tt3[p][:, v[p], :] for p in eachindex(v)])[1]
174+
@test value ≈ f(v)
175+
end
176+
end
177+
==#
178+
179+
@testset "insert_global_pivots: pivotsearch=$pivotsearch" for pivotsearch in [:full], partialnesting in [false, true]
180+
Random.seed!(1234)
181+
182+
R = 20
183+
abstol = 1e-4
184+
grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,))
185+
186+
rindex = [rand(1:2, R) for _ in 1:100]
187+
188+
f(bitlist) = fx(QD.quantics_to_origcoord(grid, bitlist)[1])
189+
rpoint = Float64[QD.quantics_to_origcoord(grid, r)[1] for r in rindex]
190+
191+
function fx(x)
192+
res = exp(-10 * x)
193+
for r in rpoint
194+
res += abs(x - r) < 1e-5 ? 2 * abstol : 0.0
195+
end
196+
res
197+
end
198+
199+
localdims = fill(2, R)
200+
firstpivot = ones(Int, R)
201+
tci, ranks, errors = crossinterpolate2(
202+
Float64,
203+
f,
204+
localdims,
205+
[firstpivot];
206+
tolerance=abstol,
207+
maxbonddim=1000,
208+
maxiter=20,
209+
loginterval=1,
210+
verbosity=0,
211+
normalizeerror=false,
212+
pivotsearch=pivotsearch,
213+
partialnesting=true
214+
)
215+
#@show sum(abs.([TCI.evaluate(tci, r) - f(r) for r in rindex]) .> abstol)
216+
217+
TCI.addglobalpivots2sitesweep!(
218+
tci, f, rindex,
219+
tolerance=abstol,
220+
normalizeerror=false,
221+
maxbonddim=1000,
222+
pivotsearch=pivotsearch,
223+
verbosity=1,
224+
partialnesting=partialnesting,
225+
ntry = (!partialnesting && pivotsearch == :full) ? 1 : 10
226+
)
227+
@test sum(abs.([TCI.evaluate(tci, r) - f(r) for r in rindex]) .> abstol) == 0
228+
end
229+
230+
#==
231+
@testset "globalsearch" begin
232+
Random.seed!(1234)
233+
234+
n = 10
235+
fx(x) = exp(-10 * x) * sin(2 * pi * 100 * x^1.1) # Nasty function
236+
f(bitlist) = fx(QD.quantics_to_origcoord(grid, bitlist)[1])
237+
grid = QD.DiscretizedGrid{1}(n, (0.0,), (1.0,))
238+
239+
localdims = fill(2, n)
240+
241+
# This checks only that the function runs without error
242+
tci, ranks, errors = crossinterpolate2(
243+
Float64,
244+
f,
245+
localdims,
246+
tolerance=1e-12,
247+
maxbonddim=100,
248+
maxiter=100,
249+
nsearchglobalpivot=10
250+
)
251+
252+
@test errors[end] < 1e-10
253+
end
254+
255+
256+
@testset "crossinterpolate2_ttcache" begin
257+
ValueType = Float64
258+
259+
N = 4
260+
bonddims = [1, 2, 3, 2, 1]
261+
@assert length(bonddims) == N + 1
262+
localdims = [2, 3, 3, 2]
263+
264+
tt = TCI.TensorTrain{ValueType,3}([rand(bonddims[n], localdims[n], bonddims[n+1]) for n in 1:N])
265+
ttc = TCI.TTCache(tt.T)
266+
267+
tci2, ranks, errors = TCI.crossinterpolate2(
268+
ValueType,
269+
ttc,
270+
localdims;
271+
tolerance=1e-10,
272+
maxbonddim = 10
273+
)
274+
275+
tt_reconst = TCI.TensorTrain(tci2)
276+
277+
vals_reconst = [tt_reconst(collect(indices)) for indices in Iterators.product((1:d for d in localdims)...)]
278+
vals_ref = [tt(collect(indices)) for indices in Iterators.product((1:d for d in localdims)...)]
279+
280+
@test vals_reconst ≈ vals_ref
281+
end
282+
==#
283+
end

0 commit comments

Comments
 (0)