Skip to content

Commit a8c09e9

Browse files
committed
Make interface of contract allow type conversion
1 parent c618386 commit a8c09e9

File tree

4 files changed

+61
-21
lines changed

4 files changed

+61
-21
lines changed

src/contraction.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,14 @@ end
487487

488488
"""
489489
function contract(
490-
A::TensorTrain{ValueType,4},
491-
B::TensorTrain{ValueType,4};
490+
A::TensorTrain{V1,4},
491+
B::TensorTrain{V2,4};
492492
algorithm::Symbol=:TCI,
493493
tolerance::Float64=1e-12,
494494
maxbonddim::Int=typemax(Int),
495495
f::Union{Nothing,Function}=nothing,
496496
kwargs...
497-
) where {ValueType}
497+
) where {V1,V2}
498498
499499
Contract two tensor trains `A` and `B`.
500500
@@ -513,45 +513,48 @@ Arguments:
513513
- `kwargs...` are forwarded to [`crossinterpolate2`](@ref) if `algorithm=:TCI`.
514514
"""
515515
function contract(
516-
A::TensorTrain{ValueType,4},
517-
B::TensorTrain{ValueType,4};
516+
A::TensorTrain{V1,4},
517+
B::TensorTrain{V2,4};
518518
algorithm::Symbol=:TCI,
519519
tolerance::Float64=1e-12,
520520
maxbonddim::Int=typemax(Int),
521521
f::Union{Nothing,Function}=nothing,
522522
kwargs...
523-
) where {ValueType}
523+
)::TensorTrain{promote_type(V1,V2),4} where {V1,V2}
524+
Vres = promote_type(V1, V2)
525+
A_ = TensorTrain{Vres,4}(A)
526+
B_ = TensorTrain{Vres,4}(B)
524527
if algorithm === :TCI
525-
return contract_TCI(A, B; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
528+
return contract_TCI(A_, B_; tolerance=tolerance, maxbonddim=maxbonddim, f=f, kwargs...)
526529
elseif algorithm === :naive
527530
if f !== nothing
528531
error("Naive contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
529532
end
530-
return contract_naive(A, B; tolerance=tolerance, maxbonddim=maxbonddim)
533+
return contract_naive(A_, B_; tolerance=tolerance, maxbonddim=maxbonddim)
531534
elseif algorithm === :zipup
532535
if f !== nothing
533536
error("Zipup contraction implementation cannot contract matrix product with a function. Use algorithm=:TCI instead.")
534537
end
535-
return contract_zipup(A, B; tolerance, maxbonddim)
538+
return contract_zipup(A_, B_; tolerance, maxbonddim)
536539
else
537540
throw(ArgumentError("Unknown algorithm $algorithm."))
538541
end
539542
end
540543

541544
function contract(
542545
A::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}},
543-
B::TensorTrain{V,4};
546+
B::TensorTrain{V2,4};
544547
kwargs...
545-
)::TensorTrain{V,3} where {V}
548+
)::TensorTrain{promote_type(V,V2),3} where {V,V2}
546549
tt = contract(TensorTrain{4}(A, [(1, s...) for s in sitedims(A)]), B; kwargs...)
547550
return TensorTrain{3}(tt, prod.(sitedims(tt)))
548551
end
549552

550553
function contract(
551554
A::TensorTrain{V,4},
552-
B::Union{TensorCI1{V},TensorCI2{V},TensorTrain{V,3}};
555+
B::Union{TensorCI1{V2},TensorCI2{V2},TensorTrain{V2,3}};
553556
kwargs...
554-
)::TensorTrain{V,3} where {V}
557+
)::TensorTrain{promote_type(V,V2),3} where {V,V2}
555558
tt = contract(A, TensorTrain{4}(B, [(s..., 1) for s in sitedims(B)]); kwargs...)
556559
return TensorTrain{3}(tt, prod.(sitedims(tt)))
557560
end

src/tensortrain.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ mutable struct TensorTrain{ValueType,N} <: AbstractTensorTrain{ValueType}
3030
end
3131
end
3232

33+
function Base.show(io::IO, obj::TensorTrain{V,N}) where {V,N}
34+
print(
35+
io,
36+
"$(typeof(obj)) of rank $(maximum(linkdims(obj)))"
37+
)
38+
end
39+
40+
function TensorTrain{V2,N}(tt::TensorTrain{V})::TensorTrain{V2,N} where {V,V2,N}
41+
return TensorTrain{V2,N}(Array{V2}.(sitetensors(tt)))
42+
end
43+
3344
"""
3445
function TensorTrain(sitetensors::Vector{Array{V, 3}}) where {V}
3546
@@ -53,7 +64,7 @@ function TensorTrain(tci::AbstractTensorTrain{V})::TensorTrain{V,3} where {V}
5364
end
5465

5566
"""
56-
function TensorTrain{N}(tci::AbstractTensorTrain{V}) where {V,N}
67+
function TensorTrain{V2,N}(tci::AbstractTensorTrain{V}) where {V,V2,N}
5768
5869
Convert a tensor-train-like object into a tensor train.
5970
@@ -62,15 +73,15 @@ Arguments:
6273
- `localdims`: a vector of local dimensions for each tensor in the tensor train. A each element
6374
of `localdims` should be an array-like object of `N-2` integers.
6475
"""
65-
function TensorTrain{V,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
76+
function TensorTrain{V2,N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V2,N} where {V,V2,N}
6677
for d in localdims
6778
length(d) == N - 2 || error("Each element of localdims be a list of N-2 integers.")
6879
end
6980
for n in 1:length(tt)
7081
prod(size(tt[n])[2:end-1]) == prod(localdims[n]) || error("The local dimensions at n=$n must match the tensor sizes.")
7182
end
72-
return TensorTrain{V,N}(
73-
[reshape(t, size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
83+
return TensorTrain{V2,N}(
84+
[reshape(Array{V2}(t), size(t, 1), localdims[n]..., size(t)[end]) for (n, t) in enumerate(sitetensors(tt))])
7485
end
7586

7687
function TensorTrain{N}(tt::AbstractTensorTrain{V}, localdims)::TensorTrain{V,N} where {V,N}
@@ -239,4 +250,14 @@ end
239250
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
240251
tensors = to_tensors(obj, x)
241252
return sum((abs2(_evaluate(tensors, indexset) - obj.values[i]) for (i, indexset) in enumerate(obj.indexsets)))
253+
end
254+
255+
256+
257+
function fulltensor(obj::TensorTrain{T,N})::Array{T} where {T,N}
258+
sitedims_ = sitedims(obj)
259+
localdims = collect(prod.(sitedims_))
260+
r = [obj(collect(Tuple(i))) for i in CartesianIndices(Tuple(localdims))]
261+
returnsize = collect(Iterators.flatten(sitedims_))
262+
return reshape(r, returnsize...)
242263
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ include("test_tensortrain.jl")
2020
include("test_conversion.jl")
2121
include("test_contraction.jl")
2222
include("test_integration.jl")
23-
include("test_globalsearch.jl")
23+
include("test_globalsearch.jl")

test/test_tensortrain.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ end
4848
L = length(linkdims) - 1
4949
localdims = fill(4, L)
5050
tts = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L])
51-
tto = TCI.TensorTrain{4}(tts, fill([2,2], L))
51+
tto = TCI.TensorTrain{4}(tts, fill([2, 2], L))
5252
tts_reconst = TCI.TensorTrain{3}(tto, localdims)
5353

5454
for n in 1:L
5555
@test all(tts[n] .== tts_reconst[n])
5656
end
5757

58-
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([2,3], L)) # Wrong shape
59-
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([1,2,3], L)) # Wrong shape
58+
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([2, 3], L)) # Wrong shape
59+
@test_throws ErrorException TCI.TensorTrain{4}(tts, fill([1, 2, 3], L)) # Wrong shape
6060
end
6161

6262
@testset "batchevaluate" begin
@@ -176,3 +176,19 @@ end
176176
indicesmultileg = @. collect(zip(indices, indices))
177177
@test ttmultileg2.(indicesmultileg) 2 .* ttmultileg.(indicesmultileg)
178178
end
179+
180+
@testset "tensor train cast" begin
181+
Random.seed!(10)
182+
localdims = [2, 2, 2]
183+
linkdims_ = [1, 2, 3, 1]
184+
L = length(localdims)
185+
186+
tt1 = TCI.TensorTrain{Float64,3}([randn(Float64, linkdims_[n], localdims[n], linkdims_[n+1]) for n in 1:L])
187+
188+
tt2 = TCI.TensorTrain{ComplexF64,3}(tt1, localdims)
189+
@test TCI.fulltensor(tt1) TCI.fulltensor(tt2)
190+
191+
tt3 = TCI.TensorTrain{Float64,3}(tt2, localdims)
192+
@test TCI.fulltensor(tt1) TCI.fulltensor(tt3)
193+
end
194+

0 commit comments

Comments
 (0)