Skip to content

Commit 320e639

Browse files
committed
update ops and parameter container
1 parent 2a730fb commit 320e639

File tree

15 files changed

+928
-26
lines changed

15 files changed

+928
-26
lines changed

docs/modules/nn.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ Layer list
1616
Sequential
1717
ModuleList
1818
ModuleDict
19+
Parameter
20+
ParameterList
21+
ParameterDict
1922

2023
Input
2124

@@ -139,6 +142,18 @@ ModuleDict
139142
^^^^^^^^^^^^^^^^
140143
.. autoclass:: ModuleDict
141144

145+
Parameter
146+
^^^^^^^^^^^^^^^^
147+
.. autofunction:: Parameter
148+
149+
ParameterList
150+
^^^^^^^^^^^^^^^^
151+
.. autoclass:: ParameterList
152+
153+
ParameterDict
154+
^^^^^^^^^^^^^^^^
155+
.. autoclass:: ParameterDict
156+
142157
.. -----------------------------------------------------------
143158
.. Input Layer
144159
.. -----------------------------------------------------------

docs/modules/ops.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ API - Operations
117117
mask_select
118118
eye
119119
einsum
120+
set_device
120121

121122
TensorLayerX Tensor Operations
122123
--------------------------------
@@ -563,4 +564,8 @@ eye
563564

564565
einsum
565566
^^^^^^^^^^^^^^^^^^^^^^^
566-
.. autofunction:: einsum
567+
.. autofunction:: einsum
568+
569+
set_device
570+
^^^^^^^^^^^^^^^^^^^^^^^
571+
.. autofunction:: set_device
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
os.environ['TL_BACKEND'] = 'tensorflow'
3+
# os.environ['TL_BACKEND'] = 'mindspore'
4+
# os.environ['TL_BACKEND'] = 'paddle'
5+
# os.environ['TL_BACKEND'] = 'torch'
6+
7+
import tensorlayerx as tlx
8+
from tensorlayerx.nn import Module, Parameter, ParameterList, ParameterDict
9+
tlx.set_device(device='CPU', id = 0)
10+
11+
class MyModule(Module):
12+
def __init__(self):
13+
super(MyModule, self).__init__()
14+
self.params1 = ParameterDict({
15+
'left': Parameter(tlx.ones((5, 10))),
16+
'right': Parameter(tlx.zeros((5, 10)))
17+
})
18+
19+
self.params2 = ParameterList(
20+
[Parameter(tlx.ones((10,5))), Parameter(tlx.ones((5,10)))]
21+
)
22+
23+
def forward(self, x, choice):
24+
x = tlx.matmul(x, self.params1[choice])
25+
x = tlx.matmul(x, self.params2[0])
26+
x = tlx.matmul(x, self.params2[1])
27+
return x
28+
29+
input = tlx.nn.Input(shape=(5,5))
30+
net = MyModule()
31+
32+
output = net(input, choice = 'right')
33+
print(output)

tensorlayerx/backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@
147147
from .ops import diag
148148
from .ops import mask_select
149149
from .ops import eye
150+
from .ops import einsum
151+
from .ops import set_device
150152
# dtype
151153
from .ops import (
152154
DType, float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool, complex64,

tensorlayerx/backend/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@
8484
from .load_backend import QuanConvBn
8585

8686
# load ops
87-
from .load_backend import einsum
8887
from .load_backend import Variable
8988
from .load_backend import matmul
9089
from .load_backend import add
@@ -200,6 +199,8 @@
200199
from .load_backend import diag
201200
from .load_backend import mask_select
202201
from .load_backend import eye
202+
from .load_backend import einsum
203+
from .load_backend import set_device
203204
# dtype
204205
from .load_backend import (
205206
DType, float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool, complex64,

tensorlayerx/backend/ops/load_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
import mindspore.context as context
5959
import os
6060
os.environ['DEVICE_ID'] = '0'
61-
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU'),
61+
context.set_context(mode=context.PYNATIVE_MODE),
6262
# context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU'),
6363
# enable_task_sink=True, enable_loop_sink=True)
6464
# context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1838,4 +1838,10 @@ def __init__(self, equation):
18381838
self.einsum = ms.ops.Einsum(equation)
18391839

18401840
def __call__(self, *args):
1841-
return self.einsum(tuple(args))
1841+
return self.einsum(tuple(args))
1842+
1843+
def set_device(device = 'GPU', id = 0):
1844+
if device not in ['GPU', 'CPU', 'Ascend']:
1845+
raise ValueError ("In mindspore, only support 'CPU', 'GPU' and 'Ascend'.")
1846+
ms.context.set_context(device_target=device)
1847+
ms.context.set_context(device_id = id)

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,3 +1856,9 @@ def __init__(self, equation):
18561856
def __call__(self, *args):
18571857
return einsum(self.equation, *args)
18581858

1859+
1860+
def set_device(device = 'GPU', id = 0):
1861+
device = device.lower()
1862+
if device == 'GPU':
1863+
device = device + ':' + str(id)
1864+
paddle.device.set_device(device)

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3867,4 +3867,27 @@ def __init__(self, equation):
38673867
self.equation = equation
38683868

38693869
def __call__(self, *args):
3870-
return tf.einsum(self.equation, *args)
3870+
return tf.einsum(self.equation, *args)
3871+
3872+
def set_device(device = 'GPU', id = 0):
3873+
"""This function can specify the global device which the OP will run.
3874+
3875+
Parameters
3876+
----------
3877+
device : str
3878+
Specific running device. It can be 'CPU', 'GPU' and 'Ascend'(In mindspore backend).
3879+
id : int
3880+
Device id.
3881+
3882+
"""
3883+
if device not in ['GPU', 'CPU']:
3884+
raise ValueError ("Only support 'CPU', 'GPU'.")
3885+
if device == 'GPU':
3886+
gpus = tf.config.experimental.list_physical_devices('GPU')
3887+
if gpus:
3888+
try:
3889+
for gpu in gpus:
3890+
tf.config.experimental.set_memory_growth(gpu, True)
3891+
tf.config.experimental.set_visible_devices(gpus[id], 'GPU')
3892+
except RuntimeError as e:
3893+
print(e)

0 commit comments

Comments
 (0)