@@ -274,7 +274,7 @@ function (obj::Contraction{T})(
274
274
tmp2 = _contract (tmp1, b[n], (2 , 5 ), (1 , 2 ))
275
275
276
276
# (left_index, S, site[n], link_a', site'[n], link_b')
277
- # => (left_index, link_a', link_b', S, site[n], site'[n])
277
+ # => (left_index, link_a', link_b', S, site[n], site'[n])
278
278
tmp3 = permutedims (tmp2, (1 , 4 , 6 , 2 , 3 , 5 ))
279
279
280
280
leftobj = reshape (tmp3, size (tmp3)[1 : 3 ]. .. , :)
@@ -299,27 +299,40 @@ function (obj::Contraction{T})(
299
299
end
300
300
301
301
302
- function contract_naive (a:: TensorTrain{T,4} , b:: TensorTrain{T,4} ):: TensorTrain{T,4} where {T}
303
- return contract_naive (Contraction (a, b))
302
+ function _contractsitetensors (a:: Array{T, 4} , b:: Array{T, 4} ):: Array{T, 4} where {T}
303
+ # indices: (link_a, s1, s2, link_a') * (link_b, s2, s3, link_b')
304
+ ab:: Array{T, 6} = _contract (a, b, (3 ,), (2 ,))
305
+ # => indices: (link_a, s1, link_a', link_b, s3, link_b')
306
+ abpermuted = permutedims (ab, (1 , 4 , 2 , 5 , 3 , 6 ))
307
+ # => indices: (link_a, link_b, s1, s3, link_a', link_b')
308
+ return reshape (abpermuted,
309
+ size (a, 1 ) * size (b, 1 ), # link_a * link_b
310
+ size (a, 2 ), size (b, 3 ), # s1, s3
311
+ size (a, 4 ) * size (b, 4 ) # link_a' * link_b'
312
+ )
313
+ end
314
+
315
+ function contract_naive (
316
+ a:: TensorTrain{T,4} , b:: TensorTrain{T,4} ;
317
+ tolerance= 0.0 , maxbonddim= typemax (Int)
318
+ ):: TensorTrain{T,4} where {T}
319
+ return contract_naive (Contraction (a, b); tolerance, maxbonddim)
304
320
end
305
321
306
- function contract_naive (obj:: Contraction{T} ):: TensorTrain{T,4} where {T}
322
+ function contract_naive (
323
+ obj:: Contraction{T} ;
324
+ tolerance= 0.0 , maxbonddim= typemax (Int)
325
+ ):: TensorTrain{T,4} where {T}
307
326
if obj. f isa Function
308
327
error (" Cannot contract matrix product with a function." )
309
328
end
310
329
311
330
a, b = obj. mpo
312
-
313
- linkdims_a = vcat (1 , linkdims (a), 1 )
314
- linkdims_b = vcat (1 , linkdims (b), 1 )
315
- linkdims_ab = linkdims_a .* linkdims_b
316
-
317
- # (link_a, s1, s2, link_a') * (link_b, s2, s3, link_b')
318
- # => (link_a, s1, link_a', link_b, s3, link_b')
319
- # => (link_a, link_b, s1, s3, link_a', link_b')
320
- sitetensors = [reshape (permutedims (_contract (obj. mpo[1 ][n], obj. mpo[2 ][n], (3 ,), (2 ,)), (1 , 4 , 2 , 5 , 3 , 6 )), linkdims_ab[n], obj. sitedims[n]. .. , linkdims_ab[n+ 1 ]) for n = 1 : length (obj)]
321
-
322
- return TensorTrain {T,4} (sitetensors)
331
+ tt = TensorTrain {T, 4} (_contractsitetensors .(sitetensors (a), sitetensors (b)))
332
+ if tolerance > 0 || maxbonddim < typemax (Int)
333
+ compress! (tt, :SVD ; tolerance, maxbonddim)
334
+ end
335
+ return tt
323
336
end
324
337
325
338
function _reshape_fusesites (t:: AbstractArray{T} ) where {T}
@@ -389,15 +402,15 @@ end
389
402
function contract (
390
403
A:: TensorTrain{ValueType,4} ,
391
404
B:: TensorTrain{ValueType,4} ;
392
- algorithm= " TCI" ,
405
+ algorithm:: Symbol = : TCI ,
393
406
tolerance:: Float64 = 1e-12 ,
394
407
maxbonddim:: Int = typemax (Int),
395
408
f:: Union{Nothing,Function} = nothing ,
396
409
kwargs...
397
410
) where {ValueType}
398
- if algorithm == " TCI"
411
+ if algorithm === : TCI
399
412
return contract_TCI (A, B; tolerance= tolerance, maxbonddim= maxbonddim, f= f, kwargs... )
400
- elseif algorithm == " naive"
413
+ elseif algorithm === : naive
401
414
return contract_naive (A, B)
402
415
else
403
416
throw (ArgumentError (" Unknown algorithm $algorithm ." ))
0 commit comments