Skip to content

Commit e5061b3

Browse files
committed
added multiplication operator
1 parent 91ff638 commit e5061b3

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

src/abstracttensortrain.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,6 @@ Addition of two tensor trains. If `c = a + b`, then `c(v) ≈ a(v) + b(v)` at ea
218218
219219
See also: [`add`](@ref)
220220
"""
221-
function (+)(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
221+
function Base.:+(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
222222
return add(lhs, rhs)
223223
end

src/tensortrain.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,36 @@ function compress!(
125125
end
126126

127127

128+
function multiply!(tt::TensorTrain{V, N}, a) where {V, N}
129+
tt.sitetensors[end] .= tt.sitetensors[end] .* a
130+
nothing
131+
end
132+
133+
function multiply!(a, tt::TensorTrain{V, N}) where {V, N}
134+
tt.sitetensors[end] .= a .* tt.sitetensors[end]
135+
nothing
136+
end
137+
138+
function multiply(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
139+
tt2 = deepcopy(tt)
140+
multiply!(tt2, a)
141+
return tt2
142+
end
143+
144+
function multiply(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
145+
tt2 = deepcopy(tt)
146+
multiply!(a, tt2)
147+
return tt2
148+
end
149+
150+
function Base.:*(tt::TensorTrain{V, N}, a)::TensorTrain{V, N} where {V, N}
151+
return multiply(tt, a)
152+
end
153+
154+
function Base.:*(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
155+
return multiply(a, tt)
156+
end
157+
128158
"""
129159
Fitting data with a TensorTrain object.
130160
This may be useful when the interpolated function is noisy.

test/test_tensortrain.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ end
119119
@test [TCI.evaluate(ttopt, idx) for (idx, v) in zip(indexsets, values)] values
120120
end
121121

122-
@testset "tensor train addition" for T in [Float64, ComplexF64]
122+
@testset "tensor train addition and multiplication" for T in [Float64, ComplexF64]
123123
Random.seed!(10)
124124
localdims = [2, 2, 2]
125125
linkdims = [1, 2, 3, 1]
@@ -134,6 +134,9 @@ end
134134
ttadd2 = tt1 + tt2
135135
@test ttadd2.(indices) [tt1(v) + tt2(v) for v in indices]
136136

137+
tt1mul = 1.6 * tt1
138+
@test tt1mul.(indices) 1.6 .* tt1.(indices)
139+
137140
ttshort = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L-1])
138141
@test_throws DimensionMismatch TCI.add(tt1, ttshort)
139142

0 commit comments

Comments
 (0)