Skip to content

Commit 9bf423e

Browse files
committed
Merge branch 'main' of gitlab.com:tensors4fields/tensorcrossinterpolation.jl
2 parents 995f805 + 397a268 commit 9bf423e

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorCrossInterpolation"
22
uuid = "b261b2ec-6378-4871-b32e-9173bb050604"
33
authors = ["Ritter.Marc <[email protected]>, Hiroshi Shinaoka <[email protected]> and contributors"]
4-
version = "0.9.0"
4+
version = "0.9.1"
55

66
[deps]
77
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
@@ -11,6 +11,7 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1111
[compat]
1212
EllipsisNotation = "1"
1313
julia = "1.6"
14+
QuadGK = "2.9"
1415

1516
[extras]
1617
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

src/tensortrain.jl

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,74 @@ function tensortrain(tci)
5656
return TensorTrain(tci)
5757
end
5858

59+
"""
    _factorize(
        A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
    ) where {V}

Factorize the matrix `A` into a two-factor product `L * R`, truncated according
to `tolerance` and `maxbonddim`.

# Arguments
- `A`: matrix to factorize.
- `method`: `:LU` (rank-revealing LU), `:CI` (cross interpolation via
  `MatrixLUCI`) or `:SVD`.
- `tolerance`: absolute truncation tolerance.
- `maxbonddim`: maximum inner dimension (rank) kept after truncation.

# Returns
A tuple `(L, R, rank)` with `A ≈ L * R`, where `rank` is the inner dimension of
the product.

# Throws
An error if `method` is not one of the supported symbols.
"""
function _factorize(
    A::Matrix{V}, method::Symbol; tolerance::Float64, maxbonddim::Int
) where {V}
    if method === :LU
        factorization = rrlu(A, abstol=tolerance, maxrank=maxbonddim)
        return left(factorization), right(factorization), npivots(factorization)
    elseif method === :CI
        factorization = MatrixLUCI(A, abstol=tolerance, maxrank=maxbonddim)
        return left(factorization), right(factorization), npivots(factorization)
    elseif method === :SVD
        factorization = LinearAlgebra.svd(A)
        # Keep all singular values above `tolerance`, capped at `maxbonddim`;
        # keep at least one so the factorization is never empty.
        trunci = min(
            replacenothing(findlast(>(tolerance), factorization.S), 1),
            maxbonddim
        )
        return (
            factorization.U[:, 1:trunci],
            Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
            trunci
        )
    else
        error("Unknown factorization method `$method`; expected :LU, :CI or :SVD.")
    end
end
83+
84+
"""
    function compress!(
        tt::TensorTrain{V, N},
        method::Symbol=:LU;
        tolerance::Float64=1e-12,
        maxbonddim=typemax(Int)
    ) where {V, N}

Compress the tensor train `tt` in place, using `LU`, `CI` or `SVD`
decompositions to truncate each bond to the given `tolerance` and
`maxbonddim`.
"""
function compress!(
    tt::TensorTrain{V, N},
    method::Symbol=:LU;
    tolerance::Float64=1e-12,
    maxbonddim=typemax(Int)
) where {V, N}
    # Left-to-right sweep: truncate each bond and absorb the right factor
    # into the neighbouring site tensor.
    for b in 1:length(tt)-1
        dims = size(tt.sitetensors[b])
        lfac, rfac, newdim = _factorize(
            reshape(tt.sitetensors[b], prod(dims[1:end-1]), dims[end]),
            method; tolerance=tolerance, maxbonddim=maxbonddim
        )
        tt.sitetensors[b] = reshape(lfac, dims[1:end-1]..., newdim)
        nextdims = size(tt.sitetensors[b+1])
        absorbed = rfac * reshape(tt.sitetensors[b+1], nextdims[1], prod(nextdims[2:end]))
        tt.sitetensors[b+1] = reshape(absorbed, newdim, nextdims[2:end]...)
    end

    # Right-to-left sweep: truncate again and absorb the left factor into the
    # neighbouring site tensor.
    for b in length(tt):-1:2
        dims = size(tt.sitetensors[b])
        lfac, rfac, newdim = _factorize(
            reshape(tt.sitetensors[b], dims[1], prod(dims[2:end])),
            method; tolerance=tolerance, maxbonddim=maxbonddim
        )
        tt.sitetensors[b] = reshape(rfac, newdim, dims[2:end]...)
        prevdims = size(tt.sitetensors[b-1])
        absorbed = reshape(tt.sitetensors[b-1], prod(prevdims[1:end-1]), prevdims[end]) * lfac
        tt.sitetensors[b-1] = reshape(absorbed, prevdims[1:end-1]..., newdim)
    end

    nothing
end
126+
59127

60128
"""
61129
Fitting data with a TensorTrain object.
@@ -83,7 +151,7 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
83151
reshape(
84152
x[obj.offsets[n]+1:obj.offsets[n+1]],
85153
size(obj.tt[n])
86-
)
154+
)
87155
for n in 1:length(obj.tt)
88156
]
89157
end
@@ -94,5 +162,5 @@ end
94162

95163
# Least-squares cost of the parameter vector `x`: sum of squared deviations
# between the tensor train evaluated at each stored index set and the
# corresponding reference value.
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
    tensors = to_tensors(obj, x)
    residual(i, indexset) = abs2(_evaluate(tensors, indexset) - obj.values[i])
    return sum(residual(i, indexset) for (i, indexset) in enumerate(obj.indexsets))
end

src/util.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,11 @@ function optfirstpivot(
107107

108108
return pivot
109109
end
110+
111+
"""
    replacenothing(value::Union{T, Nothing}, default::T)::T where {T}

Return `value` unless it is `nothing`, in which case return `default`.
"""
function replacenothing(value::Union{T, Nothing}, default::T)::T where {T}
    return isnothing(value) ? default : value
end

test/test_tensortrain.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,33 @@ using Optim
77
g(v) = 1 / (sum(v .^ 2) + 1)
88
localdims = (6, 6, 6, 6)
99
tolerance = 1e-8
10+
allindices = CartesianIndices(localdims)
11+
1012
tci, ranks, errors = TCI.crossinterpolate1(Float64, g, localdims; tolerance=tolerance)
1113
tt = TCI.TensorTrain(tci)
12-
@test rank(tci) == rank(tt)
14+
@test TCI.rank(tci) == TCI.rank(tt)
1315
@test TCI.linkdims(tci) == TCI.linkdims(tt)
1416
gsum = 0.0
15-
for i in CartesianIndices(localdims)
17+
for i in allindices
1618
@test TCI.evaluate(tci, i) ≈ TCI.evaluate(tt, i)
1719
@test tt(i) == TCI.evaluate(tt, i)
1820
functionvalue = g(Tuple(i))
1921
@test abs(TCI.evaluate(tt, i) - functionvalue) < tolerance
2022
gsum += functionvalue
2123
end
2224
@test gsum ≈ TCI.sum(tt)
25+
26+
# compress! with a bond-dimension cap must respect the cap, for every method.
for decomposition in (:LU, :CI, :SVD)
    capped = deepcopy(tt)
    TCI.compress!(capped, decomposition; maxbonddim=5)
    @test TCI.rank(capped) <= 5
end

# compress! with a loose tolerance must never increase the rank.
for decomposition in (:LU, :CI, :SVD)
    truncated = deepcopy(tt)
    TCI.compress!(truncated, decomposition; tolerance=1.0)
    @test TCI.rank(truncated) <= TCI.rank(tt)
end
2337
end
2438

2539
@testset "batchevaluate" begin

0 commit comments

Comments
 (0)