Skip to content

Commit 6cf1a0e

Browse files
committed
use new constructors for conversion in contract()
1 parent e899682 commit 6cf1a0e

File tree

2 files changed

+33
-36
lines changed

2 files changed

+33
-36
lines changed

src/contraction.jl

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,15 @@ function (obj::Contraction{T})(
299299
end
300300

301301

302-
function _contractsitetensors(a::Array{T, 4}, b::Array{T, 4})::Array{T, 4} where {T}
302+
function _contractsitetensors(a::Array{T,4}, b::Array{T,4})::Array{T,4} where {T}
303303
# indices: (link_a, s1, s2, link_a') * (link_b, s2, s3, link_b')
304-
ab::Array{T, 6} = _contract(a, b, (3,), (2,))
304+
ab::Array{T,6} = _contract(a, b, (3,), (2,))
305305
# => indices: (link_a, s1, link_a', link_b, s3, link_b')
306306
abpermuted = permutedims(ab, (1, 4, 2, 5, 3, 6))
307307
# => indices: (link_a, link_b, s1, s3, link_a', link_b')
308308
return reshape(abpermuted,
309309
size(a, 1) * size(b, 1), # link_a * link_b
310-
size(a, 2), size(b, 3), # s1, s3
310+
size(a, 2), size(b, 3), # s1, s3
311311
size(a, 4) * size(b, 4) # link_a' * link_b'
312312
)
313313
end
@@ -328,7 +328,7 @@ function contract_naive(
328328
end
329329

330330
a, b = obj.mpo
331-
tt = TensorTrain{T, 4}(_contractsitetensors.(sitetensors(a), sitetensors(b)))
331+
tt = TensorTrain{T,4}(_contractsitetensors.(sitetensors(a), sitetensors(b)))
332332
if tolerance > 0 || maxbonddim < typemax(Int)
333333
compress!(tt, :SVD; tolerance, maxbonddim)
334334
end
@@ -445,27 +445,20 @@ function contract(
445445
end
446446
end
447447

448-
function _promoteMPStoMPO(
449-
tt::AbstractTensorTrain{V},
450-
unfusedlocalshape::Union{AbstractVector{Int}, Tuple}
451-
)::TensorTrain{V, 4} where {V}
452-
return TensorTrain{V, 4}(_reshape_splitsites.(sitetensors(tt), Ref(unfusedlocalshape)))
453-
end
454-
455448
function contract(
456-
A::Union{TensorCI1{V}, TensorCI2{V}, TensorTrain{V, 3}},
457-
B::TensorTrain{V, 4};
449+
A::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}},
450+
B::TensorTrain{V,4};
458451
kwargs...
459-
)::TensorTrain{V, 3} where {V}
460-
tt = contract(_promoteMPStoMPO(A, (1, sitedim(A, 1)...)), B; kwargs...)
461-
return TensorTrain{V, 3}([T for (T, shape) in _reshape_fusesites.(sitetensors(tt))])
452+
)::TensorTrain{V,3} where {V}
453+
tt = contract(TensorTrain{4}(A, [(1, s...) for s in sitedims(A)]), B; kwargs...)
454+
return TensorTrain{3}(tt, prod.(sitedims(tt)))
462455
end
463456

464457
function contract(
465-
A::TensorTrain{V, 4},
466-
B::Union{TensorCI1{V}, TensorCI2{V}, TensorTrain{V, 3}};
458+
A::TensorTrain{V,4},
459+
B::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}};
467460
kwargs...
468-
)::TensorTrain{V, 3} where {V}
469-
tt = contract(A, _promoteMPStoMPO(B, (sitedim(B, 1)..., 1)); kwargs...)
470-
return TensorTrain{V, 3}([T for (T, shape) in _reshape_fusesites.(sitetensors(tt))])
461+
)::TensorTrain{V,3} where {V}
462+
tt = contract(A, TensorTrain{4}(B, [(s..., 1) for s in sitedims(B)]); kwargs...)
463+
return TensorTrain{3}(tt, prod.(sitedims(tt)))
471464
end

src/tensortrain.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ Arguments:
6262
- `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
6363
of `localdims` should be an array-like object of `N-2` integers.
6464
"""
65-
function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
65+
function TensorTrain{V,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
6666
for d in localdims
67-
length(d) == N-2 || error("Each element of localdims be a list of N-2 integers.")
67+
length(d) == N - 2 || error("Each element of localdims be a list of N-2 integers.")
6868
end
6969
for n in 1:length(tt)
7070
prod(size(tt[n])[2:end-1]) == prod(localdims[n]) || error("The local dimensions at n=$n must match the tensor sizes.")
@@ -73,13 +73,17 @@ function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N}
7373
[reshape(t, size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
7474
end
7575

76+
function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
77+
return TensorTrain{V,N}(tt, localdims)
78+
end
79+
7680
function tensortrain(tci)
7781
return TensorTrain(tci)
7882
end
7983

8084
function _factorize(
8185
A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
82-
)::Tuple{Matrix{V}, Matrix{V}, Int} where {V}
86+
)::Tuple{Matrix{V},Matrix{V},Int} where {V}
8387
if method === :LU
8488
factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
8589
return left(factorization), right(factorization), npivots(factorization)
@@ -113,11 +117,11 @@ end
113117
Compress the tensor train `tt` using `LU`, `CI` or `SVD` decompositions.
114118
"""
115119
function compress!(
116-
tt::TensorTrain{V, N},
120+
tt::TensorTrain{V,N},
117121
method::Symbol=:LU;
118122
tolerance::Float64=1e-12,
119123
maxbonddim::Int=typemax(Int)
120-
) where {V, N}
124+
) where {V,N}
121125
for ell in 1:length(tt)-1
122126
shapel = size(tt.sitetensors[ell])
123127
left, right, newbonddim = _factorize(
@@ -146,48 +150,48 @@ function compress!(
146150
end
147151

148152

149-
function multiply!(tt::TensorTrain{V, N}, a) where {V, N}
153+
function multiply!(tt::TensorTrain{V,N}, a) where {V,N}
150154
tt.sitetensors[end] .= tt.sitetensors[end] .* a
151155
nothing
152156
end
153157

154-
function multiply!(a, tt::TensorTrain{V, N}) where {V, N}
158+
function multiply!(a, tt::TensorTrain{V,N}) where {V,N}
155159
tt.sitetensors[end] .= a .* tt.sitetensors[end]
156160
nothing
157161
end
158162

159-
function multiply(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
163+
function multiply(tt::TensorTrain{V,N}, a)::TensorTrain{V,N} where {V,N}
160164
tt2 = deepcopy(tt)
161165
multiply!(tt2, a)
162166
return tt2
163167
end
164168

165-
function multiply(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
169+
function multiply(a, tt::TensorTrain{V,N})::TensorTrain{V,N} where {V,N}
166170
tt2 = deepcopy(tt)
167171
multiply!(a, tt2)
168172
return tt2
169173
end
170174

171-
function Base.:*(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
175+
function Base.:*(tt::TensorTrain{V,N}, a)::TensorTrain{V,N} where {V,N}
172176
return multiply(tt, a)
173177
end
174178

175-
function Base.:*(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
179+
function Base.:*(a, tt::TensorTrain{V,N})::TensorTrain{V,N} where {V,N}
176180
return multiply(a, tt)
177181
end
178182

179-
function divide!(tt::TensorTrain{V, N}, a) where {V, N}
183+
function divide!(tt::TensorTrain{V,N}, a) where {V,N}
180184
tt.sitetensors[end] .= tt.sitetensors[end] ./ a
181185
nothing
182186
end
183187

184-
function divide(tt::TensorTrain{V, N}, a) where {V, N}
188+
function divide(tt::TensorTrain{V,N}, a) where {V,N}
185189
tt2 = deepcopy(tt)
186190
divide!(tt2, a)
187191
return tt2
188192
end
189193

190-
function Base.:/(tt::TensorTrain{V, N}, a) where {V, N}
194+
function Base.:/(tt::TensorTrain{V,N}, a) where {V,N}
191195
return divide(tt, a)
192196
end
193197

@@ -222,7 +226,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
222226
]
223227
end
224228

225-
function _evaluate(tt::Vector{Array{V, 3}}, indexset) where {V}
229+
function _evaluate(tt::Vector{Array{V,3}}, indexset) where {V}
226230
only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
227231
end
228232

0 commit comments

Comments
 (0)