Skip to content

Commit 011bfd5

Browse files
committed
added TensorCI2(TensorTrain) constructor
1 parent cd63b04 commit 011bfd5

File tree

4 files changed

+144
-1
lines changed

4 files changed

+144
-1
lines changed

src/conversion.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,108 @@ function TensorCI2{ValueType}(tci1::TensorCI1{ValueType}) where {ValueType}
6969
tci2.maxsamplevalue = tci1.maxsamplevalue
7070
return tci2
7171
end
72+
73+
function sweep1sitegetindices!(
74+
tt::TensorTrain{ValueType,3}, forwardsweep::Bool,
75+
spectatorindices::Vector{Vector{MultiIndex}}=Vector{MultiIndex}[];
76+
maxbonddim=typemax(Int), tolerance=0.0
77+
) where {ValueType}
78+
indexset = Vector{MultiIndex}[MultiIndex[[]]]
79+
pivoterrorsarray = zeros(rank(tt) + 1)
80+
81+
function groupindices(T::AbstractArray, next::Bool)
82+
shape = size(T)
83+
if forwardsweep != next
84+
reshape(T, prod(shape[1:end-1]), shape[end])
85+
else
86+
reshape(T, shape[1], prod(shape[2:end]))
87+
end
88+
end
89+
90+
function splitindices(T::AbstractArray, shape, newbonddim, next::Bool)
91+
if forwardsweep != next
92+
newshape = (shape[1:end-1]..., newbonddim)
93+
else
94+
newshape = (newbonddim, shape[2:end]...)
95+
end
96+
reshape(T, newshape)
97+
end
98+
99+
L = length(tt)
100+
for i in 1:L-1
101+
ell = forwardsweep ? i : L - i + 1
102+
ellnext = forwardsweep ? i + 1 : L - i
103+
shape = size(tt.sitetensors[ell])
104+
shapenext = size(tt.sitetensors[ellnext])
105+
106+
luci = MatrixLUCI(
107+
groupindices(tt.sitetensors[ell], false), leftorthogonal=forwardsweep,
108+
abstol=tolerance, maxrank=maxbonddim
109+
)
110+
111+
if forwardsweep
112+
push!(indexset, kronecker(last(indexset), shape[2])[rowindices(luci)])
113+
if !isempty(spectatorindices)
114+
spectatorindices[ell] = spectatorindices[ell][colindices(luci)]
115+
end
116+
else
117+
push!(indexset, kronecker(shape[2], last(indexset))[colindices(luci)])
118+
if !isempty(spectatorindices)
119+
spectatorindices[ell] = spectatorindices[ell][rowindices(luci)]
120+
end
121+
end
122+
123+
124+
tt.sitetensors[ell] = splitindices(
125+
forwardsweep ? left(luci) : right(luci),
126+
shape, npivots(luci), false
127+
)
128+
129+
nexttensor = (
130+
forwardsweep
131+
? right(luci) * groupindices(tt.sitetensors[ellnext], true)
132+
: groupindices(tt.sitetensors[ellnext], true) * left(luci)
133+
)
134+
135+
tt.sitetensors[ellnext] = splitindices(nexttensor, shapenext, npivots(luci), true)
136+
pivoterrorsarray[1:npivots(luci) + 1] = max.(pivoterrorsarray[1:npivots(luci) + 1], pivoterrors(luci))
137+
end
138+
139+
if forwardsweep
140+
return indexset, pivoterrorsarray
141+
else
142+
return reverse(indexset), pivoterrorsarray
143+
end
144+
end
145+
146+
function TensorCI2{ValueType}(
147+
tt::TensorTrain{ValueType,3}; tolerance=1e-12, maxbonddim=typemax(Int), maxiter=3
148+
) where {ValueType}
149+
local pivoterrors::Vector{Float64}
150+
151+
Iset, = sweep1sitegetindices!(tt, true; maxbonddim, tolerance)
152+
Jset, pivoterrors = sweep1sitegetindices!(tt, false; maxbonddim, tolerance)
153+
154+
for iter in 3:maxiter
155+
if isodd(iter)
156+
Isetnew, pivoterrors = sweep1sitegetindices!(tt, true, Jset)
157+
if Isetnew == Iset
158+
break
159+
end
160+
else
161+
Jsetnew, pivoterrors = sweep1sitegetindices!(tt, false, Iset)
162+
if Jsetnew == Jset
163+
break
164+
end
165+
end
166+
end
167+
168+
tci2 = TensorCI2{ValueType}(first.(sitedims(tt)))
169+
tci2.Iset = Iset
170+
tci2.Jset = Jset
171+
tci2.sitetensors = sitetensors(tt)
172+
tci2.pivoterrors = pivoterrors
173+
tci2.maxsamplevalue = maximum(maximum.(abs, tci2.sitetensors))
174+
175+
return tci2
176+
end

src/tensortrain.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,12 @@ function Base.:/(tt::TensorTrain{V,N}, a) where {V,N}
195195
return divide(tt, a)
196196
end
197197

198+
function Base.reverse(tt::AbstractTensorTrain{V}) where {V}
199+
return tensortrain(reverse([
200+
permutedims(T, (ndims(T), (2:ndims(T)-1)..., 1)) for T in sitetensors(tt)
201+
]))
202+
end
203+
198204
"""
199205
Fitting data with a TensorTrain object.
200206
This may be useful when the interpolated function is noisy.

test/test_conversion.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,28 @@ end
7171
@test tci2(v) f(v)
7272
end
7373
end
74+
75+
@testset "Conversion between TT and TCI2" begin
76+
f(v) = (1.0 + 2.0im) ./ (sum(v .^ 2) + 1)
77+
tci, = crossinterpolate2(ComplexF64, f, fill(4, 4); tolerance=1e-14, maxbonddim=5)
78+
tt = tensortrain(tci)
79+
tcib = TensorCI2{ComplexF64}(tt; tolerance=1e-14)
80+
81+
@test rank(tt) == 5
82+
@test linkdims(tt) == linkdims(tci)
83+
@test sitedims(tt) == fill([4], 4)
84+
85+
@test rank(tcib) == 5
86+
@test linkdims(tcib) == linkdims(tt)
87+
@test sitedims(tcib) == fill([4], 4)
88+
89+
for v in Iterators.product([1:4 for _ in 1:4]...)
90+
@test abs(tt(v) - tci(v)) < 1e-13
91+
@test abs(tcib(v) - tci(v)) < 1e-13
92+
end
93+
94+
optimize!(tcib, f; tolerance=1e-14)
95+
for v in Iterators.product([1:4 for _ in 1:4]...)
96+
@test abs(tcib(v) - f(v)) < 1e-13
97+
end
98+
end

test/test_tensortrain.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,24 @@ using Optim
1010

1111
tci, ranks, errors = TCI.crossinterpolate1(ComplexF64, g, localdims; tolerance=tolerance)
1212
tt = TCI.TensorTrain(tci)
13+
ttr = TCI.reverse(tt)
1314
@test TCI.rank(tci) == TCI.rank(tt)
15+
@test TCI.rank(tci) == TCI.rank(ttr)
1416
@test TCI.linkdims(tci) == TCI.linkdims(tt)
17+
@test TCI.linkdims(tci) == reverse(TCI.linkdims(ttr))
1518
gsum = ComplexF64(0.0)
1619
for i in allindices
1720
@test TCI.evaluate(tci, i) TCI.evaluate(tt, i)
21+
@test TCI.evaluate(tci, i) TCI.evaluate(ttr, reverse(Tuple(i)))
1822
@test tt(i) == TCI.evaluate(tt, i)
23+
@test tt(i) TCI.evaluate(tt, reverse(Tuple(i)))
1924
functionvalue = g(Tuple(i))
2025
@test abs(TCI.evaluate(tt, i) - functionvalue) < tolerance
26+
@test abs(TCI.evaluate(ttr, reverse(Tuple(i))) - functionvalue) < tolerance
2127
gsum += functionvalue
2228
end
2329
@test gsum TCI.sum(tt)
30+
@test gsum TCI.sum(ttr)
2431

2532
for method in [:LU, :CI, :SVD]
2633
ttcompressed = deepcopy(tt)
@@ -46,7 +53,7 @@ end
4653

4754
for n in 1:L
4855
@test all(tts[n] .== tts_reconst[n])
49-
end
56+
end
5057

5158
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([2,3], L)) # Wrong shape
5259
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([1,2,3], L)) # Wrong shape

0 commit comments

Comments
 (0)