Skip to content

Commit fe0eb8a

Browse files
committed
added compression step to naive contraction
1 parent 37dbd7c commit fe0eb8a

File tree

3 files changed

+33
-20
lines changed

3 files changed

+33
-20
lines changed

src/contraction.jl

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ function (obj::Contraction{T})(
274274
tmp2 = _contract(tmp1, b[n], (2, 5), (1, 2))
275275

276276
# (left_index, S, site[n], link_a', site'[n], link_b')
277-
# => (left_index, link_a', link_b', S, site[n], site'[n])
277+
# => (left_index, link_a', link_b', S, site[n], site'[n])
278278
tmp3 = permutedims(tmp2, (1, 4, 6, 2, 3, 5))
279279

280280
leftobj = reshape(tmp3, size(tmp3)[1:3]..., :)
@@ -299,27 +299,40 @@ function (obj::Contraction{T})(
299299
end
300300

301301

302-
function contract_naive(a::TensorTrain{T,4}, b::TensorTrain{T,4})::TensorTrain{T,4} where {T}
303-
return contract_naive(Contraction(a, b))
302+
function _contractsitetensors(a::Array{T, 4}, b::Array{T, 4})::Array{T, 4} where {T}
303+
# indices: (link_a, s1, s2, link_a') * (link_b, s2, s3, link_b')
304+
ab::Array{T, 6} = _contract(a, b, (3,), (2,))
305+
# => indices: (link_a, s1, link_a', link_b, s3, link_b')
306+
abpermuted = permutedims(ab, (1, 4, 2, 5, 3, 6))
307+
# => indices: (link_a, link_b, s1, s3, link_a', link_b')
308+
return reshape(abpermuted,
309+
size(a, 1) * size(b, 1), # link_a * link_b
310+
size(a, 2), size(b, 3), # s1, s3
311+
size(a, 4) * size(b, 4) # link_a' * link_b'
312+
)
313+
end
314+
315+
function contract_naive(
316+
a::TensorTrain{T,4}, b::TensorTrain{T,4};
317+
tolerance=0.0, maxbonddim=typemax(Int)
318+
)::TensorTrain{T,4} where {T}
319+
return contract_naive(Contraction(a, b); tolerance, maxbonddim)
304320
end
305321

306-
function contract_naive(obj::Contraction{T})::TensorTrain{T,4} where {T}
322+
function contract_naive(
323+
obj::Contraction{T};
324+
tolerance=0.0, maxbonddim=typemax(Int)
325+
)::TensorTrain{T,4} where {T}
307326
if obj.f isa Function
308327
error("Cannot contract matrix product with a function.")
309328
end
310329

311330
a, b = obj.mpo
312-
313-
linkdims_a = vcat(1, linkdims(a), 1)
314-
linkdims_b = vcat(1, linkdims(b), 1)
315-
linkdims_ab = linkdims_a .* linkdims_b
316-
317-
# (link_a, s1, s2, link_a') * (link_b, s2, s3, link_b')
318-
# => (link_a, s1, link_a', link_b, s3, link_b')
319-
# => (link_a, link_b, s1, s3, link_a', link_b')
320-
sitetensors = [reshape(permutedims(_contract(obj.mpo[1][n], obj.mpo[2][n], (3,), (2,)), (1, 4, 2, 5, 3, 6)), linkdims_ab[n], obj.sitedims[n]..., linkdims_ab[n+1]) for n = 1:length(obj)]
321-
322-
return TensorTrain{T,4}(sitetensors)
331+
tt = TensorTrain{T, 4}(_contractsitetensors.(sitetensors(a), sitetensors(b)))
332+
if tolerance > 0 || maxbonddim < typemax(Int)
333+
compress!(tt, :SVD; tolerance, maxbonddim)
334+
end
335+
return tt
323336
end
324337

325338
function _reshape_fusesites(t::AbstractArray{T}) where {T}
@@ -389,15 +402,15 @@ end
389402
function contract(
390403
A::TensorTrain{ValueType,4},
391404
B::TensorTrain{ValueType,4};
392-
algorithm="TCI",
405+
algorithm::Symbol=:TCI,
393406
tolerance::Float64=1e-12,
394407
maxbonddim::Int=typemax(Int),
395408
f::Union{Nothing,Function}=nothing,
396409
kwargs...
397410
) where {ValueType}
398-
if algorithm == "TCI"
411+
if algorithm === :TCI
399412
return contract_TCI(A, B; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
400-
elseif algorithm == "naive"
413+
elseif algorithm === :naive
401414
return contract_naive(A, B)
402415
else
403416
throw(ArgumentError("Unknown algorithm $algorithm."))

src/tensortrain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858

5959
function _factorize(
6060
A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
61-
) where {V}
61+
)::Tuple{Matrix{V}, Matrix{V}, Int} where {V}
6262
if method === :LU
6363
factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
6464
return left(factorization), right(factorization), npivots(factorization)

test/test_contraction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ end
4848
@test _tomat(ab) f.(_tomat(a) * _tomat(b))
4949
end
5050
end
51-
end
51+
end

0 commit comments

Comments
 (0)