Skip to content

Commit db4e35c

Browse files
authored
Merge pull request #38 from tensor4all/37-reconstructing-a-full-tensor-using-matrix-multiplications
Vectorize fulltensor()
2 parents baf6ca5 + 29e3215 commit db4e35c

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

src/tensortrain.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,17 @@ function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType
253253
end
254254

255255

256-
257256
function fulltensor(obj::TensorTrain{T,N})::Array{T} where {T,N}
258257
sitedims_ = sitedims(obj)
259258
localdims = collect(prod.(sitedims_))
260-
r = [obj(collect(Tuple(i))) for i in CartesianIndices(Tuple(localdims))]
259+
result::Matrix{T} = reshape(obj.sitetensors[1], localdims[1], :)
260+
leftdim = localdims[1]
261+
for l in 2:length(obj)
262+
nextmatrix = reshape(
263+
obj.sitetensors[l], size(obj.sitetensors[l], 1), localdims[l] * size(obj.sitetensors[l])[end])
264+
leftdim *= localdims[l]
265+
result = reshape(result * nextmatrix, leftdim, size(obj.sitetensors[l])[end])
266+
end
261267
returnsize = collect(Iterators.flatten(sitedims_))
262-
return reshape(r, returnsize...)
268+
return reshape(result, returnsize...)
263269
end

test/test_tensortrain.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,25 @@ using Optim
4343
end
4444

4545

46+
function _fulltensor(obj::TCI.TensorTrain{T,N})::Array{T} where {T,N}
47+
sitedims_ = TCI.sitedims(obj)
48+
localdims = collect(prod.(sitedims_))
49+
r = [obj(collect(Tuple(i))) for i in CartesianIndices(Tuple(localdims))]
50+
returnsize = collect(Iterators.flatten(sitedims_))
51+
return reshape(r, returnsize...)
52+
end
53+
54+
@testset "TT fulltensor" for T in [Float64, ComplexF64]
55+
linkdims = [1, 2, 3, 1]
56+
L = length(linkdims) - 1
57+
localdims = fill(4, L)
58+
tts = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L])
59+
tto = TCI.TensorTrain{4}(tts, fill([2, 2], L))
60+
61+
@test _fulltensor(tts) TCI.fulltensor(tts)
62+
end
63+
64+
4665
@testset "TT shape conversion" for T in [Float64, ComplexF64]
4766
linkdims = [1, 2, 3, 1]
4867
L = length(linkdims) - 1

0 commit comments

Comments
 (0)