Skip to content

Commit 7f801a9

Browse files
committed
added divide function
1 parent e5061b3 commit 7f801a9

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/tensortrain.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,21 @@ function Base.:*(a, tt::TensorTrain{V, N})::TensorTrain{V, N} where {V, N}
155155
return multiply(a, tt)
156156
end
157157

158+
function divide!(tt::TensorTrain{V, N}, a) where {V, N}
159+
tt.sitetensors[end] .= tt.sitetensors[end] ./ a
160+
nothing
161+
end
162+
163+
function divide(tt::TensorTrain{V, N}, a) where {V, N}
164+
tt2 = deepcopy(tt)
165+
divide!(tt2, a)
166+
return tt2
167+
end
168+
169+
function Base.:/(tt::TensorTrain{V, N}, a) where {V, N}
170+
return divide(tt, a)
171+
end
172+
158173
"""
159174
Fitting data with a TensorTrain object.
160175
This may be useful when the interpolated function is noisy.

test/test_tensortrain.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ end
137137
tt1mul = 1.6 * tt1
138138
@test tt1mul.(indices) 1.6 .* tt1.(indices)
139139

140+
tt1div = tt1mul / 3.2
141+
@test tt1div.(indices) tt1.(indices) ./ 2.0
142+
140143
ttshort = TCI.TensorTrain{T,3}([randn(T, linkdims[n], localdims[n], linkdims[n+1]) for n in 1:L-1])
141144
@test_throws DimensionMismatch TCI.add(tt1, ttshort)
142145

0 commit comments

Comments
 (0)