Skip to content

Commit 31893ac

Browse files
committed
bugfix: compress! function if no singular value is > tolerance
1 parent 76e76d4 commit 31893ac

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

src/tensortrain.jl

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,40 +56,61 @@ function tensortrain(tci)
5656
return TensorTrain(tci)
5757
end
5858

59+
function _factorize(
60+
A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
61+
) where {V}
62+
if method === :LU
63+
factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
64+
return left(factorization), right(factorization), npivots(factorization)
65+
elseif method === :CI
66+
factorization = MatrixLUCI(A, abstol=tolerance, maxrank=maxbonddim)
67+
return left(factorization), right(factorization), npivots(factorization)
68+
elseif method === :SVD
69+
factorization = LinearAlgebra.svd(A)
70+
trunci = min(
71+
replacenothing(findlast(>(tolerance), factorization.S), 1),
72+
maxbonddim
73+
)
74+
return (
75+
factorization.U[:, 1:trunci],
76+
Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
77+
trunci
78+
)
79+
else
80+
error("Not implemented yet.")
81+
end
82+
end
83+
5984
function compress!(
60-
tt::AbstractTensorTrain{V},
85+
tt::TensorTrain{V, N},
6186
method::Symbol=:LU;
6287
tolerance::Float64=1e-12,
6388
maxbonddim=typemax(Int)
64-
) where {V}
65-
function factorize(A::Matrix{V})
66-
if method === :LU
67-
factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
68-
return left(factorization), right(factorization), npivots(factorization)
69-
elseif method === :CI
70-
factorization = MatrixLUCI(A, abstol=tolerance, maxrank=maxbonddim)
71-
return left(factorization), right(factorization), npivots(factorization)
72-
elseif method === :SVD
73-
factorization = LinearAlgebra.svd(A)
74-
trunci = min(findlast(>(tolerance), factorization.S), maxbonddim)
75-
return (
76-
factorization.U[:, 1:trunci],
77-
Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
78-
trunci
79-
)
80-
else
81-
error("Not implemented yet.")
82-
end
83-
end
84-
89+
) where {V, N}
8590
for ell in 1:length(tt)-1
8691
shapel = size(tt.sitetensors[ell])
87-
left, right, newbonddim = factorize(reshape(tt.sitetensors[ell], prod(shapel[1:end-1]), shapel[end]))
92+
left, right, newbonddim = _factorize(
93+
reshape(tt.sitetensors[ell], prod(shapel[1:end-1]), shapel[end]),
94+
method; tolerance, maxbonddim
95+
)
8896
tt.sitetensors[ell] = reshape(left, shapel[1:end-1]..., newbonddim)
8997
shaper = size(tt.sitetensors[ell+1])
9098
nexttensor = right * reshape(tt.sitetensors[ell+1], shaper[1], prod(shaper[2:end]))
9199
tt.sitetensors[ell+1] = reshape(nexttensor, newbonddim, shaper[2:end]...)
92100
end
101+
102+
for ell in length(tt):-1:2
103+
shaper = size(tt.sitetensors[ell])
104+
left, right, newbonddim = _factorize(
105+
reshape(tt.sitetensors[ell], shaper[1], prod(shaper[2:end])),
106+
method; tolerance, maxbonddim
107+
)
108+
tt.sitetensors[ell] = reshape(right, newbonddim, shaper[2:end]...)
109+
shapel = size(tt.sitetensors[ell-1])
110+
nexttensor = reshape(tt.sitetensors[ell-1], prod(shapel[1:end-1]), shapel[end]) * left
111+
tt.sitetensors[ell-1] = reshape(nexttensor, shapel[1:end-1]..., newbonddim)
112+
end
113+
93114
nothing
94115
end
95116

src/util.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,11 @@ function optfirstpivot(
107107

108108
return pivot
109109
end
110+
111+
function replacenothing(value::Union{T, Nothing}, default::T)::T where {T}
112+
if isnothing(value)
113+
return default
114+
else
115+
return value
116+
end
117+
end

0 commit comments

Comments
 (0)