Skip to content

Commit 5006403

Browse files
committed
add scatter_nd
1 parent ed56fb9 commit 5006403

File tree

5 files changed

+16
-0
lines changed

5 files changed

+16
-0
lines changed

tensorlayerx/backend/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
from .load_backend import cast
105105
from .load_backend import transpose
106106
from .load_backend import gather_nd
107+
from .load_backend import scatter_nd
107108
from .load_backend import clip_by_value
108109
from .load_backend import split
109110
from .load_backend import get_tensor_shape

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,10 @@ def gather_nd(params, indices, batch_dims=0):
10811081
return op(params, indices)
10821082

10831083

1084+
def scatter_nd(indices, updates, shape):
1085+
raise NotImplementedError
1086+
1087+
10841088
class ClipGradByValue(object):
10851089
def __init__(self, clip_min=-1, clip_max=1):
10861090
self.min = ms.Tensor(clip_min)

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,10 @@ def gather_nd(params, indices, batch_dims=0):
849849
return pd.gather_nd(params, indices)
850850

851851

852+
def scatter_nd(indices, updates, shape):
853+
raise NotImplementedError
854+
855+
852856
class ClipGradByValue(pd.nn.ClipGradByValue):
853857
def __init__(self, clip_min=-1, clip_max=1):
854858
super().__init__(max=clip_max, min=clip_min)

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,10 @@ def gather_nd(params, indices, batch_dims=0):
11181118
return tf.gather_nd(params, indices, batch_dims)
11191119

11201120

1121+
def scatter_nd(indices, updates, shape):
1122+
return tf.scatter_nd(indices, updates, shape)
1123+
1124+
11211125
class ClipGradByValue(object):
11221126
def __init__(self, clip_min=-1, clip_max=1):
11231127
self.min = clip_min

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,9 @@ def gather_nd(params, indices, batch_dims=0):
910910
out = torch.take(params, idx)
911911
return out.view(out_shape)
912912

913+
def scatter_nd(indices, updates, shape):
914+
raise NotImplementedError
915+
913916

914917
class ClipGradByValue(object):
915918
def __init__(self, clip_min=-1, clip_max=1):

0 commit comments

Comments
 (0)