Skip to content

Commit 313515d

Browse files
committed
Merge branch '42-conversion-from-3-leg-tt-to-4-leg-tt-for-ttos' into 'main'
Implement conversion between TTs with different shapes Closes #42 See merge request tensors4fields/TensorCrossInterpolation.jl!83
2 parents 4803711 + 6cf1a0e commit 313515d

File tree

3 files changed

+69
-34
lines changed

3 files changed

+69
-34
lines changed

src/contraction.jl

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,15 @@ function (obj::Contraction{T})(
299299
end
300300

301301

302-
function _contractsitetensors(a::Array{T, 4}, b::Array{T, 4})::Array{T, 4} where {T}
302+
function _contractsitetensors(a::Array{T,4}, b::Array{T,4})::Array{T,4} where {T}
303303
# indices: (link_a, s1, s2, link_a') * (link_b, s2, s3, link_b')
304-
ab::Array{T, 6} = _contract(a, b, (3,), (2,))
304+
ab::Array{T,6} = _contract(a, b, (3,), (2,))
305305
# => indices: (link_a, s1, link_a', link_b, s3, link_b')
306306
abpermuted = permutedims(ab, (1, 4, 2, 5, 3, 6))
307307
# => indices: (link_a, link_b, s1, s3, link_a', link_b')
308308
return reshape(abpermuted,
309309
size(a, 1) * size(b, 1), # link_a * link_b
310-
size(a, 2), size(b, 3), # s1, s3
310+
size(a, 2), size(b, 3), # s1, s3
311311
size(a, 4) * size(b, 4) # link_a' * link_b'
312312
)
313313
end
@@ -328,7 +328,7 @@ function contract_naive(
328328
end
329329

330330
a, b = obj.mpo
331-
tt = TensorTrain{T, 4}(_contractsitetensors.(sitetensors(a), sitetensors(b)))
331+
tt = TensorTrain{T,4}(_contractsitetensors.(sitetensors(a), sitetensors(b)))
332332
if tolerance > 0 || maxbonddim < typemax(Int)
333333
compress!(tt, :SVD; tolerance, maxbonddim)
334334
end
@@ -445,27 +445,20 @@ function contract(
445445
end
446446
end
447447

448-
function _promoteMPStoMPO(
449-
tt::AbstractTensorTrain{V},
450-
unfusedlocalshape::Union{AbstractVector{Int}, Tuple}
451-
)::TensorTrain{V, 4} where {V}
452-
return TensorTrain{V, 4}(_reshape_splitsites.(sitetensors(tt), Ref(unfusedlocalshape)))
453-
end
454-
455448
function contract(
456-
A::Union{TensorCI1{V}, TensorCI2{V}, TensorTrain{V, 3}},
457-
B::TensorTrain{V, 4};
449+
A::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}},
450+
B::TensorTrain{V,4};
458451
kwargs...
459-
)::TensorTrain{V, 3} where {V}
460-
tt = contract(_promoteMPStoMPO(A, (1, sitedim(A, 1)...)), B; kwargs...)
461-
return TensorTrain{V, 3}([T for (T, shape) in _reshape_fusesites.(sitetensors(tt))])
452+
)::TensorTrain{V,3} where {V}
453+
tt = contract(TensorTrain{4}(A, [(1, s...) for s in sitedims(A)]), B; kwargs...)
454+
return TensorTrain{3}(tt, prod.(sitedims(tt)))
462455
end
463456

464457
function contract(
465-
A::TensorTrain{V, 4},
466-
B::Union{TensorCI1{V}, TensorCI2{V}, TensorTrain{V, 3}};
458+
A::TensorTrain{V,4},
459+
B::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}};
467460
kwargs...
468-
)::TensorTrain{V, 3} where {V}
469-
tt = contract(A, _promoteMPStoMPO(B, (sitedim(B, 1)..., 1)); kwargs...)
470-
return TensorTrain{V, 3}([T for (T, shape) in _reshape_fusesites.(sitetensors(tt))])
461+
)::TensorTrain{V,3} where {V}
462+
tt = contract(A, TensorTrain{4}(B, [(s..., 1) for s in sitedims(B)]); kwargs...)
463+
return TensorTrain{3}(tt, prod.(sitedims(tt)))
471464
end

src/tensortrain.jl

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,38 @@ function TensorTrain(tci::AbstractTensorTrain{V})::TensorTrain{V,3} where {V}
5252
return TensorTrain{V,3}(sitetensors(tci))
5353
end
5454

55+
"""
56+
function TensorTrain{N}(tci::AbstractTensorTrain{V}) where {V,N}
57+
58+
Convert a tensor-train-like object into a tensor train.
59+
60+
Arguments:
61+
- `tt::AbstractTensorTrain{V}`: a tensor-train-like object.
62+
- `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
63+
of `localdims` should be an array-like object of `N-2` integers.
64+
"""
65+
function TensorTrain{V,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
66+
for d in localdims
67+
length(d) == N - 2 || error("Each element of localdims be a list of N-2 integers.")
68+
end
69+
for n in 1:length(tt)
70+
prod(size(tt[n])[2:end-1]) == prod(localdims[n]) || error("The local dimensions at n=$n must match the tensor sizes.")
71+
end
72+
return TensorTrain{V,N}(
73+
[reshape(t, size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
74+
end
75+
76+
function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
77+
return TensorTrain{V,N}(tt, localdims)
78+
end
79+
5580
function tensortrain(tci)
5681
return TensorTrain(tci)
5782
end
5883

5984
function _factorize(
6085
A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
61-
)::Tuple{Matrix{V}, Matrix{V}, Int} where {V}
86+
)::Tuple{Matrix{V},Matrix{V},Int} where {V}
6287
if method === :LU
6388
factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
6489
return left(factorization), right(factorization), npivots(factorization)
@@ -92,11 +117,11 @@ end
92117
Compress the tensor train `tt` using `LU`, `CI` or `SVD` decompositions.
93118
"""
94119
function compress!(
95-
tt::TensorTrain{V, N},
120+
tt::TensorTrain{V,N},
96121
method::Symbol=:LU;
97122
tolerance::Float64=1e-12,
98123
maxbonddim::Int=typemax(Int)
99-
) where {V, N}
124+
) where {V,N}
100125
for ell in 1:length(tt)-1
101126
shapel = size(tt.sitetensors[ell])
102127
left, right, newbonddim = _factorize(
@@ -125,48 +150,48 @@ function compress!(
125150
end
126151

127152

128-
function multiply!(tt::TensorTrain{V, N}, a) where {V, N}
153+
function multiply!(tt::TensorTrain{V,N}, a) where {V,N}
129154
tt.sitetensors[end] .= tt.sitetensors[end] .* a
130155
nothing
131156
end
132157

133-
function multiply!(a, tt::TensorTrain{V, N}) where {V, N}
158+
function multiply!(a, tt::TensorTrain{V,N}) where {V,N}
134159
tt.sitetensors[end] .= a .* tt.sitetensors[end]
135160
nothing
136161
end
137162

138-
function multiply(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
163+
function multiply(tt::TensorTrain{V,N}, a)::TensorTrain{V,N} where {V,N}
139164
tt2 = deepcopy(tt)
140165
multiply!(tt2, a)
141166
return tt2
142167
end
143168

144-
function multiply(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
169+
function multiply(a, tt::TensorTrain{V,N})::TensorTrain{V,N} where {V,N}
145170
tt2 = deepcopy(tt)
146171
multiply!(a, tt2)
147172
return tt2
148173
end
149174

150-
function Base.:*(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
175+
function Base.:*(tt::TensorTrain{V,N}, a)::TensorTrain{V,N} where {V,N}
151176
return multiply(tt, a)
152177
end
153178

154-
function Base.:*(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
179+
function Base.:*(a, tt::TensorTrain{V,N})::TensorTrain{V,N} where {V,N}
155180
return multiply(a, tt)
156181
end
157182

158-
function divide!(tt::TensorTrain{V, N}, a) where {V, N}
183+
function divide!(tt::TensorTrain{V,N}, a) where {V,N}
159184
tt.sitetensors[end] .= tt.sitetensors[end] ./ a
160185
nothing
161186
end
162187

163-
function divide(tt::TensorTrain{V, N}, a) where {V, N}
188+
function divide(tt::TensorTrain{V,N}, a) where {V,N}
164189
tt2 = deepcopy(tt)
165190
divide!(tt2, a)
166191
return tt2
167192
end
168193

169-
function Base.:/(tt::TensorTrain{V, N}, a) where {V, N}
194+
function Base.:/(tt::TensorTrain{V,N}, a) where {V,N}
170195
return divide(tt, a)
171196
end
172197

@@ -201,7 +226,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
201226
]
202227
end
203228

204-
function _evaluate(tt::Vector{Array{V, 3}}, indexset) where {V}
229+
function _evaluate(tt::Vector{Array{V,3}}, indexset) where {V}
205230
only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
206231
end
207232

test/test_tensortrain.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@ using Optim
3535
end
3636
end
3737

38+
39+
@testset "TT shape conversion" for T in [Float64, ComplexF64]
40+
linkdims = [1, 2, 3, 1]
41+
L = length(linkdims) - 1
42+
localdims = fill(4, L)
43+
tts = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L])
44+
tto = TCI.TensorTrain{4}(tts, fill([2,2], L))
45+
tts_reconst = TCI.TensorTrain{3}(tto, localdims)
46+
47+
for n in 1:L
48+
@test all(tts[n] .== tts_reconst[n])
49+
end
50+
51+
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([2,3], L)) # Wrong shape
52+
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([1,2,3], L)) # Wrong shape
53+
end
54+
3855
@testset "batchevaluate" begin
3956
N = 4
4057
#bonddims = fill(3, N + 1)

0 commit comments

Comments
 (0)