Skip to content

Parallel TCI, Parallel zipup, Fit, Parallel fit and minor functions #65

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
*.jl.mem
/Manifest.toml
/docs/build/
LocalPreferences.toml
.vscode/settings.json
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.9.17"
BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

Expand All @@ -16,6 +17,7 @@ EllipsisNotation = "1"
QuadGK = "2.9"
Random = "1.10.0"
julia = "1.6"
MPI = "0.20.22"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
5 changes: 5 additions & 0 deletions docs/src/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ Modules = [TensorCrossInterpolation]
Pages = ["cachedfunction.jl", "batcheval.jl", "util.jl", "globalsearch.jl"]
```

## Parallel utility
```@autodocs
Modules = [TensorCrossInterpolation]
Pages = ["mpi.jl"]
```
5 changes: 4 additions & 1 deletion src/TensorCrossInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ module TensorCrossInterpolation
using LinearAlgebra
using EllipsisNotation
using BitIntegers
import QuadGK
using MPI
using Base.Threads

import QuadGK
# To add a method for rank(tci)
import LinearAlgebra: rank, diag
import LinearAlgebra as LA
Expand Down Expand Up @@ -40,5 +42,6 @@ include("conversion.jl")
include("integration.jl")
include("contraction.jl")
include("globalsearch.jl")
include("mpi.jl")

end
228 changes: 224 additions & 4 deletions src/abstracttensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,26 @@ function evaluate(
return only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
end

"""
function evaluate(
tt::TensorTrain{V},
indexset::Union{AbstractVector{LocalIndex}, NTuple{N, LocalIndex}}
)::V where {N, V}

Evaluates the tensor train `tt` at indices given by `indexset` and `jndexset`. This is ment to be used for MPOs.
"""
function evaluate(
tt::AbstractTensorTrain{V},
indexset::Union{AbstractVector{LocalIndex},NTuple{N,LocalIndex}},
jndexset::Union{AbstractVector{LocalIndex},NTuple{N,LocalIndex}}
)::V where {N,V}
if length(indexset) != length(tt)
throw(ArgumentError("To evaluate a tt of length $(length(tt)), you have to provide $(length(tt)) indices, but there were $(length(indexset))."))
end
return only(prod(T[:, i, j, :] for (T, i, j) in zip(tt, indexset, jndexset)))
end


"""
function evaluate(tt::TensorTrain{V}, indexset::CartesianIndex) where {V}

Expand Down Expand Up @@ -175,6 +195,38 @@ function sum(tt::AbstractTensorTrain{V}) where {V}
return only(v)
end

"""
function average(tt::TensorTrain{V}) where {V}

Evaluates the average of the tensor train approximation over all lattice sites in an efficient
factorized manner.
"""
function average(tt::AbstractTensorTrain{V}) where {V}
v = transpose(sum(tt[1], dims=(1, 2))[1, 1, :]) / length(tt[1][1, :, 1])
for T in tt[2:end]
v *= sum(T, dims=2)[:, 1, :] / length(T[1, :, 1])
end
return only(v)
end

"""
function weightedsum(tt::TensorTrain{V}, w::Vector{V}) where {V}

Evaluates the weighted sum of the tensor train approximation over all lattice sites in an efficient
factorized manner, where w is the vector of vector of weights which has the same length and the same sizes as tt.
"""
function weightedsum(tt::AbstractTensorTrain{V}, w::Vector{Vector{V}}) where {V}
length(tt) == length(w) || throw(DimensionMismatch("The length of the Tensor Train is different from the one of the weight vector ($(length(tt)) and $(length(w)))."))
size(tt[1])[2] == length(w[1]) || throw(DimensionMismatch("The dimension at site 1 of the Tensor Train is different from the one of the weight vector ($(size(tt[1])[2]) and $(length(w[1])))."))
v = transpose(sum(tt[1].*w[1]', dims=(1, 2))[1, 1, :])
for i in 2:length(tt)
size(tt[i])[2] == length(w[i]) || throw(DimensionMismatch("The dimension at site $(i) of the Tensor Train is different from the one of the weight vector ($(size(tt[i])[2]) and $(length(w[i])))."))
v *= sum(tt[i].*w[i]', dims=2)[:, 1, :]
end
return only(v)
end


function _addtttensor(
A::Array{V}, B::Array{V};
factorA=one(V), factorB=one(V),
Expand Down Expand Up @@ -215,7 +267,7 @@ See also: [`+`](@ref)
function add(
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
factorlhs=one(V), factorrhs=one(V),
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int), normalizeerror::Bool=true
) where {V}
if length(lhs) != length(rhs)
throw(DimensionMismatch("Two tensor trains with different length ($(length(lhs)) and $(length(rhs))) cannot be added elementwise."))
Expand All @@ -233,7 +285,7 @@ function add(
for ell in 1:L
]
)
compress!(tt, :SVD; tolerance, maxbonddim)
compress!(tt, :SVD; tolerance, maxbonddim, normalizeerror)
return tt
end

Expand All @@ -247,9 +299,9 @@ Subtract two tensor trains `lhs` and `rhs`. See [`add`](@ref).
"""
function subtract(
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int), normalizeerror::Bool=true
) where {V}
return add(lhs, rhs; factorrhs=-1 * one(V), tolerance, maxbonddim)
return add(lhs, rhs; factorrhs=-1 * one(V), tolerance, maxbonddim, normalizeerror)
end

@doc raw"""
Expand All @@ -270,6 +322,174 @@ function Base.:-(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where
return subtract(lhs, rhs)
end

function leftcanonicalize!(tt::AbstractTensorTrain{ValueType}) where {ValueType}
n = length(tt) # Number of sites
for i in 1:n-1
Q, R = qr(reshape(tt[i], prod(size(tt[i])[1:end-1]), size(tt[i])[end]))
Q = Matrix(Q)

tt[i] = reshape(Q, size(tt[i])[1:end-1]..., size(Q, 2)) # New bond dimension after Q

tmptt = reshape(tt[i+1], size(R, 2), :) # Reshape next tensor
tmptt .= Matrix(R) * tmptt
tt[i+1] = reshape(tmptt, size(tt[i+1])...) # Reshape back
end
end

# This creates a TensorTrain which has every site right-canonical except the last
function rightcanonicalize!(tt::AbstractTensorTrain{ValueType}) where {ValueType}
n = length(tt) # Number of sites
for i in n:-1:2
# Reshape W_i into a matrix (merging right bond and physical indices)
W = tt[i]
χl, d1, d2, χr = size(W)
W_mat = reshape(W, χl, d1*d2*χr)

# Perform RQ decottsition: W_mat = R * Q
F = lq(reverse(W_mat, dims=1))
R, Q = reverse(F.L), reverse(Matrix(F.Q), dims=1) # https://discourse.julialang.org/t/rq-decomposition/112795/13

# Reshape Q back into the MPO tensor
tt[i] = reshape(Q, size(Q, 1), d1, d2, χr) # New bond dimension after Q

# Update the previous tt tensor by absorbing R
tmptt = reshape(tt[i-1], :, size(R, 1)) # Reshape previous tensor
tmptt .= tmptt * Matrix(R)
tt[i-1] = reshape(tmptt, size(tt[i-1], 1), d1, d2, size(tt[i-1], 4)) # Reshape back
end
end


function centercanonicalize!(tt::Vector{Array{ValueType, N}}, center::Int; old_center::Int=0) where {ValueType, N}
orthogonality = checkorthogonality(tt)
n = length(tt) # Number of sites

if count(==( :N ), orthogonality) == 1
old_center_ = findfirst(==( :N ), orthogonality)
if old_center_ == nothing # Useless, but help JET compiling
old_center_ = old_center
end
# println("Sto canonicalizzando centrando in $center. ho trovato il centro in $old_center_. Quindi flipperò: $(center < old_center_ ? [size(tt[i]) for i in center:old_center_] : [size(tt[i]) for i in old_center_:center])")
if old_center != 0 && old_center != old_center_
println("Warning! In centercanonicalize!() old_center has been set as $old_center, but the real old center is $old_center_")
end
elseif old_center == 0
old_center_ = 1
else
old_center_ = old_center
end
# LEFT
for i in old_center_:center-1
Q, R = qr(reshape(tt[i], prod(size(tt[i])[1:end-1]), size(tt[i])[end]))
Q = Matrix(Q)

tt[i] = reshape(Q, size(tt[i])[1:end-1]..., size(Q, 2)) # New bond dimension after Q

tmptt = reshape(tt[i+1], size(R, 2), :) # Reshape next tensor
tmptt = Matrix(R) * tmptt
tt[i+1] = reshape(tmptt, size(tt[i+1])...) # Reshape back
end
# RIGHT
if count(==( :N ), orthogonality) == 1
old_center_ = findfirst(==( :N ), orthogonality)
if old_center_ == nothing # Useless, but help JET compiling
old_center_ = old_center
end
if old_center != 0 && old_center != old_center_
println("Warning! In centercanonicalize!() old_center has been set as $old_center, but the real old center is $old_center_")
end
elseif old_center == 0
old_center_ = n
else
old_center_ = old_center
end
for i in old_center_:-1:center+1
W = tt[i]
χl, d1, d2, χr = size(W)
W_mat = reshape(W, χl, d1*d2*χr)

L, Q = lq(W_mat)
Q = Matrix(Q)
# Reshape Q back into the tt tensor
tt[i] = reshape(Q, size(Q, 1), d1, d2, χr) # New bond dimension after Q

# Update the previous tt tensor by absorbing L
tmptt = reshape(tt[i-1], :, size(L, 1)) # Reshape previous tensor
tmptt = tmptt * Matrix(L)
tt[i-1] = reshape(tmptt, size(tt[i-1], 1), d1, d2, size(tmptt, 2)) # Reshape back
end
end

function move_center_right!(tt, i)
A = tt[i]
d = size(A)
A_mat = reshape(A, prod(d[1:end-1]), d[end])
Q, R = qr(A_mat)
Q = Matrix(Q)
tt[i] = reshape(Q, d[1:end-1]..., size(Q, 2))

B = tt[i+1]
B_mat = reshape(B, size(R, 2), :)
B_mat .= Matrix(R) * B_mat
tt[i+1] = reshape(B_mat, size(B)...)
end

function move_center_left!(tt, i)
A = tt[i]
d = size(A)
A_mat = reshape(A, d[1], prod(d[2:end]))
L, Q = lq(A_mat)
Q = Matrix(Q)
tt[i] = reshape(Q, size(Q,1), d[2:end]...)

B = tt[i-1]
B_mat = reshape(B, :, size(L,1))
B_mat .= B_mat * Matrix(L)
tt[i-1] = reshape(B_mat, size(B)[1:3]..., size(L,1))
end


function leftcanonicalize(tt::AbstractTensorTrain{ValueType}) where {ValueType}
tt_ = deepcopy(tt)
leftcanonicalize!(tt_)
return tt_
end

# This creates a TensorTrain which has every site right-canonical except the last
function rightcanonicalize(tt::AbstractTensorTrain{ValueType}) where {ValueType}
tt_ = deepcopy(tt)
rightcanonicalize!(tt_)
return tt_
end

# This creates a TensorTrain which has every site right-canonical except the last
function centercanonicalize(tt::Vector{Array{ValueType, N}}, center::Int; old_center::Int=0) where {ValueType, N}
tt_ = deepcopy(tt)
centercanonicalize!(tt_, center; old_center)
return tt_
end

function checkorthogonality(tt::Vector{Array{ValueType, N}}) where {ValueType, N}
ort = Vector{Symbol}(undef, length(tt))
for i in 1:length(tt)
W = tt[i]
left_check = _contract(permutedims(W, (4,2,3,1,)), W, (2,3,4,),(2,3,1))
right_check = _contract(W, permutedims(W, (4,2,3,1,)), (2,3,4,),(2,3,1))
is_left = isapprox(left_check, I, atol=1e-7)
is_right = isapprox(right_check, I, atol=1e-7)
ort[i] = if is_left && is_right
:O # Orthogonal
elseif is_left
:L # Left orthogonal
elseif is_right
:R # Right orthogonal
else
:N # Non orthogonal
end
end
return ort
end

"""
Squared Frobenius norm of a tensor train.
"""
Expand Down
Loading
Loading