Skip to content

Commit c5eb1d4

Browse files
committed
fixed conversion TCI2 -> TCI1
1 parent cac998d commit c5eb1d4

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

src/abstracttensortrain.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,11 @@ function _addtttensor(
191191
end
192192

193193
@doc raw"""
194-
function add(
195-
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
196-
factorlhs=one(V), factorrhs=one(V),
197-
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
198-
) where {V}
194+
function add(
195+
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
196+
factorlhs=one(V), factorrhs=one(V),
197+
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
198+
) where {V}
199199
200200
Addition of two tensor trains. If `C = add(A, B)`, then `C(v) ≈ A(v) + B(v)` at each index set `v`. Note that this function increases the bond dimension, i.e. ``\chi_{\text{result}} = \chi_1 + \chi_2`` if the original tensor trains had bond dimensions ``\chi_1`` and ``\chi_2``.
201201

src/conversion.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function TensorCI1{ValueType}(
2626
kwargs...
2727
) where {ValueType}
2828
L = length(tci2)
29-
tci1 = TensorCI1{ValueType}(length.(tci2.localset))
29+
tci1 = TensorCI1{ValueType}(tci2.localdims)
3030
tci1.Iset = IndexSet.(tci2.Iset)
3131
tci1.Jset = IndexSet.(tci2.Jset)
3232
tci1.PiIset = getPiIset.(Ref(tci1), 1:L)
@@ -51,22 +51,21 @@ function TensorCI1{ValueType}(
5151
end
5252
tci1.P[end] = ones(ValueType, 1, 1)
5353

54-
tci1.pivoterrors = max.(tci2.bonderrorsforward, tci2.bonderrorsbackward)
54+
tci1.pivoterrors = tci2.bonderrors
5555
tci1.maxsamplevalue = tci2.maxsamplevalue
5656
return tci1
5757
end
5858

5959
function TensorCI2{ValueType}(tci1::TensorCI1{ValueType}) where {ValueType}
60-
tci2 = TensorCI2{ValueType}(vcat(sitedims(tci1)...)::Vector{Int})
60+
tci2 = TensorCI2{ValueType}(vcat(collect.(sitedims(tci1))...)::Vector{Int})
6161
tci2.Iset = [i.fromint for i in tci1.Iset]
6262
tci2.Jset = [j.fromint for j in tci1.Jset]
63-
tci2.localset = tci1.localset
63+
tci2.localdims = tci1.localdims
6464
L = length(tci1)
6565
tci2.sitetensors[1:L-1] = TtimesPinv.(tci1, 1:L-1)
6666
tci2.sitetensors[end] = tci1.T[end]
6767
tci2.pivoterrors = Float64[]
68-
tci2.bonderrorsforward = tci1.pivoterrors
69-
tci2.bonderrorsbackward = tci1.pivoterrors
68+
tci2.bonderrors = tci1.pivoterrors
7069
tci2.maxsamplevalue = tci1.maxsamplevalue
7170
return tci2
7271
end

src/tensorci1.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ function linkdim(tci::TensorCI1{V}, i::Int) where {V}
159159
end
160160

161161
function sitedims(tci::TensorCI1{V}) where {V}
162-
return [size(T)[2:end-1] for T in tci.T]
162+
return [collect(size(T)[2:end-1]) for T in tci.T]
163163
end
164164

165165
function sitedim(tci::TensorCI1{V}, i::Int) where {V}
166-
return size(tci.T[i])[2:end-1]
166+
return collect(size(tci.T[i])[2:end-1])
167167
end
168168

169169
"""

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ include("test_batcheval.jl")
1616
include("test_tensorci1.jl")
1717
include("test_tensorci2.jl")
1818
include("test_tensortrain.jl")
19+
include("test_conversion.jl")
1920
include("test_contraction.jl")
20-
include("test_integration.jl")
21+
include("test_integration.jl")

test/test_conversion.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test
22
import TensorCrossInterpolation: TensorCI1, TensorCI2, sitedims, linkdims, rank,
3-
addglobalpivot!, crossinterpolate1, optimize!, MatrixACA, rrlu,
4-
nrows, ncols, evaluate, left, right
3+
addglobalpivot!, crossinterpolate1, crossinterpolate2, optimize!, MatrixACA, rrlu,
4+
nrows, ncols, evaluate, left, right, pivoterror, tensortrain
55

66
@testset "Conversion between rrLU and ACA" begin
77
A = [
@@ -31,7 +31,7 @@ end
3131
@test rank(tci2) == 0
3232
@test all(isempty.(tci2.Iset))
3333
@test all(isempty.(tci2.Jset))
34-
@test all(isempty.(tci2.T))
34+
@test all(isempty.(tci2.sitetensors))
3535

3636
globalpivot = [2, 2, 3, 1]
3737
tci1 = TensorCI1{ComplexF64}(rand, fill(d, n), globalpivot)

0 commit comments

Comments
 (0)