Skip to content

Commit 302a7ae

Browse files
committed
update ops
1 parent 1ae2442 commit 302a7ae

File tree

7 files changed

+80
-9
lines changed

7 files changed

+80
-9
lines changed

docs/modules/ops.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ API - Operations
112112
unsorted_segment_max
113113
set_seed
114114
is_tensor
115+
tensor_scatter_nd_update
116+
diag
115117

116118
TensorLayerX Tensor Operations
117119
--------------------------------
@@ -538,4 +540,12 @@ set_seed
538540

539541
is_tensor
540542
^^^^^^^^^^^^^^^^^^^^^^^
541-
.. autofunction:: is_tensor
543+
.. autofunction:: is_tensor
544+
545+
tensor_scatter_nd_update
546+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
547+
.. autofunction:: tensor_scatter_nd_update
548+
549+
diag
550+
^^^^^^^^^^^^^^^^^^^^^^^
551+
.. autofunction:: diag

tensorlayerx/backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@
143143
from .ops import unsorted_segment_max
144144
from .ops import set_seed
145145
from .ops import is_tensor
146+
from .ops import tensor_scatter_nd_update
147+
from .ops import diag
146148
# dtype
147149
from .ops import (
148150
DType, float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool, complex64,

tensorlayerx/backend/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@
195195
from .load_backend import unsorted_segment_max
196196
from .load_backend import set_seed
197197
from .load_backend import is_tensor
198+
from .load_backend import tensor_scatter_nd_update
199+
from .load_backend import diag
198200
# dtype
199201
from .load_backend import (
200202
DType, float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool, complex64,

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,13 @@ def is_tensor(x):
18181818
return isinstance(x, ms.Tensor)
18191819

18201820
def tensor_scatter_nd_update(tensor, indices, updates):
1821-
1821+
if not isinstance(tensor, ms.Tensor) or not isinstance(updates, ms.Tensor):
1822+
raise TypeError("tensor, updates should be Tensor, but got tensor type is {}, "
1823+
"and updates type is {}.".format(type(tensor), type(updates)))
1824+
indices = ms.Tensor(indices)
18221825
op = ms.ops.TensorScatterUpdate()
1823-
return op(tensor, indices, updates)
1826+
return op(tensor, indices, updates)
1827+
1828+
def diag(input, diagonal=0):
1829+
1830+
return ms.numpy.diag(input, diagonal)

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1531,8 +1531,15 @@ def is_tensor(x):
15311531

15321532

15331533
def tensor_scatter_nd_update(tensor, indices, updates):
1534+
tensor = paddle.to_tensor(tensor)
1535+
indices = paddle.to_tensor(indices)
1536+
updates = paddle.to_tensor(updates)
15341537
a = pd.scatter_nd(indices, pd.ones_like(updates), tensor.shape)
15351538
a = pd.multiply(tensor, -a)
15361539
tensor = tensor + a
15371540
x = pd.scatter_nd_add(tensor, indices, updates)
1538-
return x
1541+
return x
1542+
1543+
def diag(input, diagonal=0):
1544+
1545+
return paddle.diag(input, diagonal)

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3605,13 +3605,52 @@ def tensor_scatter_nd_update(tensor, indices, updates):
36053605
36063606
Parameters
36073607
----------
3608-
tensor
3609-
indices
3610-
updates
3608+
tensor : Tensor
3609+
tensor to update.
3610+
indices : list
3611+
indices to update.
3612+
updates : Tensor
3613+
value to apply at the indices
36113614
36123615
Returns
36133616
-------
3617+
updated Tensor.
36143618
3619+
Examples
3620+
---------
3621+
>>> import tensorlayerx as tlx
3622+
>>> tensor = tlx.ops.ones(shape=(5,3))
3623+
>>> indices = [[0],[ 4],[ 2]]
3624+
>>> updates = tlx.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
3625+
>>> new_tensor = tlx.ops.tensor_scatter_nd_update(tensor, indices, updates)
3626+
>>> [[1. 2. 3.]
3627+
>>> [1. 1. 1.]
3628+
>>> [7. 8. 9.]
3629+
>>> [1. 1. 1.]
3630+
>>> [4. 5. 6.]]
36153631
"""
36163632

3617-
return tf.tensor_scatter_nd_update(tensor, indices, updates)
3633+
return tf.tensor_scatter_nd_update(tensor, indices, updates)
3634+
3635+
def diag(input, diagonal=0):
3636+
"""
3637+
3638+
Parameters
3639+
----------
3640+
input : Tensor
3641+
the input tensor.
3642+
diagonal : int
3643+
the diagonal to consider. Defualt is 0.
3644+
3645+
Returns
3646+
-------
3647+
the output tensor.
3648+
3649+
Examples
3650+
---------
3651+
>>> import tensorlayerx as tlx
3652+
>>> tensor = tlx.ops.convert_to_tensor([[1,2,3],[4,5,6],[7,8,9]]))
3653+
>>> new_tensor = tlx.ops.diag(tensor)
3654+
>>> [1, 5, 9]
3655+
"""
3656+
return tf.experimental.numpy.diag(input, diagonal)

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1633,9 +1633,13 @@ def is_tensor(x):
16331633
return isinstance(x, torch.Tensor)
16341634

16351635
def tensor_scatter_nd_update(tensor, indices, updates):
1636-
1636+
tensor = torch.tensor(tensor)
1637+
indices = torch.tensor(indices)
1638+
updates = torch.tensor(updates)
16371639
indices = torch.flatten(indices)
16381640
tensor[indices] = updates
16391641
return tensor
16401642

1643+
def diag(input, diagonal=0):
16411644

1645+
return torch.diag(input, diagonal)

0 commit comments

Comments
 (0)