Skip to content

Commit b3960f7

Browse files
committed
Merge branch 'main' of github.com:tensorlayer/TensorLayerX into main
2 parents 94668f9 + e02ddf2 commit b3960f7

File tree

7 files changed

+165
-10
lines changed

7 files changed

+165
-10
lines changed

docs/modules/ops.rst

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,14 @@ API - Operations
113113
set_seed
114114
is_tensor
115115
tensor_scatter_nd_update
116+
scatter_update
116117
diag
117118
mask_select
118119
eye
119120
einsum
120121
set_device
122+
get_device
123+
to_device
121124

122125
TensorLayerX Tensor Operations
123126
--------------------------------
@@ -550,6 +553,10 @@ tensor_scatter_nd_update
550553
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
551554
.. autofunction:: tensor_scatter_nd_update
552555

556+
scatter_update
557+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
558+
.. autofunction:: scatter_update
559+
553560
diag
554561
^^^^^^^^^^^^^^^^^^^^^^^
555562
.. autofunction:: diag
@@ -568,4 +575,12 @@ einsum
568575

569576
set_device
570577
^^^^^^^^^^^^^^^^^^^^^^^
571-
.. autofunction:: set_device
578+
.. autofunction:: set_device
579+
580+
get_device
581+
^^^^^^^^^^^^^^^^^^^^^^^
582+
.. autofunction:: get_device
583+
584+
to_device
585+
^^^^^^^^^^^^^^^^^^^^^^^
586+
.. autofunction:: to_device

tensorlayerx/backend/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@
149149
from .ops import eye
150150
from .ops import einsum
151151
from .ops import set_device
152+
from .ops import get_device
153+
from .ops import scatter_update
154+
from .ops import to_device
152155
# dtype
153156
from .ops import (
154157
DType, float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool, complex64,

tensorlayerx/backend/ops/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@
201201
from .load_backend import eye
202202
from .load_backend import einsum
203203
from .load_backend import set_device
204+
from .load_backend import get_device
205+
from .load_backend import scatter_update
206+
from .load_backend import to_device
204207
# dtype
205208
from .load_backend import (
206209
DType, float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool, complex64,

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def set_context(**kwargs):
6868

6969

7070
def get_tensor_shape(x):
71-
return list(P.Shape()(x))
71+
return list(x.shape)
7272

7373

7474
# initializers
@@ -1844,4 +1844,24 @@ def set_device(device = 'GPU', id = 0):
18441844
if device not in ['GPU', 'CPU', 'Ascend']:
18451845
raise ValueError ("In mindspore, only support 'CPU', 'GPU' and 'Ascend'.")
18461846
ms.context.set_context(device_target=device)
1847-
ms.context.set_context(device_id = id)
1847+
ms.context.set_context(device_id = id)
1848+
1849+
def scatter_update(tensor, indices, updates):
1850+
if not isinstance(tensor, ms.Tensor) or not isinstance(updates, ms.Tensor):
1851+
raise TypeError("tensor, updates should be Tensor, but got tensor type is {}, "
1852+
"and updates type is {}.".format(type(tensor), type(updates)))
1853+
indices = ms.Tensor(indices)
1854+
shape = indices.shape
1855+
indices = ms.ops.reshape(indices, (shape[0], 1))
1856+
op = ms.ops.TensorScatterUpdate()
1857+
return op(tensor, indices, updates)
1858+
1859+
def get_device():
1860+
device = ms.context.get_context("device_target")
1861+
id = ms.context.get_context("device_id")
1862+
device = device + ":" +str(id)
1863+
return device
1864+
1865+
def to_device(tensor, device = 'GPU', id = 0):
1866+
1867+
return tensor

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def set_context(**kwargs):
8383

8484

8585
def get_tensor_shape(x):
86-
return pd.shape(x)
86+
return list(x.shape)
8787

8888

8989
# initializers
@@ -1211,7 +1211,10 @@ def floor(x):
12111211

12121212

12131213
def gather(params, indices, axis=None):
1214-
1214+
if axis < 0:
1215+
axis = len(params.shape) + axis
1216+
if axis is None:
1217+
axis = 0
12151218
return pd.gather(params, indices, axis)
12161219

12171220

@@ -1859,6 +1862,21 @@ def __call__(self, *args):
18591862

18601863
def set_device(device = 'GPU', id = 0):
18611864
device = device.lower()
1862-
if device == 'GPU':
1865+
if device == 'gpu':
18631866
device = device + ':' + str(id)
1864-
paddle.device.set_device(device)
1867+
paddle.device.set_device(device)
1868+
1869+
def scatter_update(tensor, indices, updates):
1870+
1871+
return pd.scatter(tensor, indices, updates)
1872+
1873+
def get_device():
1874+
1875+
return paddle.device.get_device()
1876+
1877+
def to_device(tensor, device = 'GPU', id = 0):
1878+
device = device.upper()
1879+
if device == 'GPU':
1880+
return paddle.to_tensor(tensor, place=paddle.CUDAPlace(id))
1881+
if device == 'CPU':
1882+
return paddle.to_tensor(tensor, place=paddle.CPUPlace())

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3890,4 +3890,75 @@ def set_device(device = 'GPU', id = 0):
38903890
tf.config.experimental.set_memory_growth(gpu, True)
38913891
tf.config.experimental.set_visible_devices(gpus[id], 'GPU')
38923892
except RuntimeError as e:
3893-
print(e)
3893+
print(e)
3894+
3895+
def scatter_update(tensor, indices, updates):
3896+
"""Applies sparse updates to a variable
3897+
3898+
Parameters
3899+
----------
3900+
tensor : Tensor
3901+
A Tensor. The dim of tensor must be 1.
3902+
indices : Tensor
3903+
Indices into the tensor.
3904+
updates : Tensor
3905+
Updated values
3906+
3907+
Returns
3908+
-------
3909+
Tensor after updated.
3910+
3911+
Examples
3912+
---------
3913+
>>> import tensorlayerx as tlx
3914+
>>> x = tlx.ops.ones((5,))
3915+
>>> indices = tlx.ops.convert_to_tensor([0, 4, 2])
3916+
>>> updates = tlx.ops.convert_to_tensor([1., 4., 7.])
3917+
>>> res = tlx.ops.scatter_update(x, indices, updates)
3918+
>>> [1. 1. 7. 1. 4.]
3919+
"""
3920+
shape = indices.shape
3921+
indices = tf.reshape(indices, shape = (shape[0], 1))
3922+
return tf.tensor_scatter_nd_update(tensor, indices, updates)
3923+
3924+
def get_device():
3925+
"""This function can get the specified global device.
3926+
3927+
Returns
3928+
-------
3929+
The global device.
3930+
3931+
Examples
3932+
---------
3933+
>>> import tensorlayerx as tlx
3934+
>>> x = tlx.ops.get_device()
3935+
>>> "CPU"
3936+
"""
3937+
device = tf.config.experimental.get_visible_devices('GPU')
3938+
if len(device) == 0:
3939+
device = tf.config.experimental.get_visible_devices('CPU')
3940+
return device
3941+
3942+
def to_device(tensor, device = 'GPU', id = 0):
3943+
"""Returns a copy of Tensor in specified device.
3944+
3945+
Parameters
3946+
----------
3947+
tensor : Tensor
3948+
A tensor.
3949+
device : str
3950+
The specified device. Support 'GPU' and 'CPU'. Default is 'GPU'.
3951+
id : int
3952+
The id of specified device. Default is 0.
3953+
3954+
3955+
Examples
3956+
---------
3957+
>>> import tensorlayerx as tlx
3958+
>>> x = tlx.ops.ones((5,))
3959+
>>> x = tlx.ops.to_device(x, device="GPU", id=0)
3960+
"""
3961+
if device is None:
3962+
return tensor
3963+
with tf.device("/" + device.upper()+':'+str(id)):
3964+
return tf.identity(tensor)

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,7 +1639,7 @@ def is_tensor(x):
16391639

16401640
def tensor_scatter_nd_update(tensor, indices, updates):
16411641
tensor = torch.tensor(tensor)
1642-
indices = torch.tensor(indices)
1642+
indices = torch.tensor(indices, dtype=torch.long)
16431643
updates = torch.tensor(updates)
16441644
indices = torch.flatten(indices)
16451645
tensor[indices] = updates
@@ -1666,6 +1666,8 @@ def mask_select(x, mask, axis = 0):
16661666
return x[:,:,:, mask]
16671667

16681668
def eye(n, m=None, dtype=None):
1669+
if m is None:
1670+
m = n
16691671
return torch.eye(n = n, m = m, dtype =dtype)
16701672

16711673

@@ -1684,4 +1686,27 @@ def __call__(self, *args):
16841686
def set_device(device = 'GPU', id = 0):
16851687
if device == 'GPU':
16861688
torch.set_default_tensor_type('torch.cuda.FloatTensor')
1687-
torch.cuda.set_device(id)
1689+
torch.cuda.set_device(id)
1690+
1691+
def scatter_update(tensor, indices, updates):
1692+
tensor = torch.tensor(tensor)
1693+
indices = torch.tensor(indices, dtype=torch.long)
1694+
updates = torch.tensor(updates)
1695+
tensor[indices] = updates
1696+
return tensor
1697+
1698+
def get_device():
1699+
try:
1700+
id = torch.cuda.current_device()
1701+
device = 'GPU:' + str(id)
1702+
return device
1703+
except:
1704+
device = 'CPU'
1705+
return device
1706+
1707+
def to_device(tensor, device='GPU', id=0):
1708+
device = device.lower()
1709+
if device == 'gpu':
1710+
device = 'cuda' + ':' + str(id)
1711+
tensor = tensor.detach().to(device)
1712+
return tensor

0 commit comments

Comments
 (0)