Commit 226606c

Merge branch 'main' of github.com:tensorlayer/TensorLayerX into main
2 parents: c131090 + 8cf6c5a

File tree

19 files changed: +579 −66 lines


README.md

Lines changed: 5 additions & 2 deletions
```diff
@@ -14,9 +14,12 @@
 [![Downloads](https://pepy.tech/badge/tensorlayerx/week)](https://pepy.tech/project/tensorlayerx/week)
 [![Docker Pulls](https://img.shields.io/docker/pulls/tensorlayer/tensorlayer.svg)](https://hub.docker.com/r/tensorlayer/tensorlayer/)
 
-🇬🇧 TensorLayerX is a multi-backend AI framework that supports TensorFlow, PyTorch, MindSpore, PaddlePaddle, OneFlow and Jittor as backends, allowing users to run the same code on different hardware such as Nvidia-GPU and Huawei-Ascend. [supported layers](https://shimo.im/sheets/kJGCCTxXvqj99RGV/F5m5Z).
+🇬🇧 TensorLayerX is a multi-backend AI framework that supports TensorFlow, PyTorch, MindSpore, PaddlePaddle, OneFlow and Jittor as backends, allowing users to run the same code on different hardware such as Nvidia-GPU and Huawei-Ascend.
+This project is maintained by researchers from Peking University, Imperial College London, Princeton, Stanford, Tsinghua, Edinburgh and Peng Cheng Lab.
+[supported layers](https://shimo.im/sheets/kJGCCTxXvqj99RGV/F5m5Z).
 
-🇨🇳 TensorLayerX is a cross-platform development framework that supports TensorFlow, PyTorch, MindSpore, PaddlePaddle, OneFlow and Jittor; users can run the same code on all kinds of operating systems and AI hardware (such as Nvidia-GPU and Huawei-Ascend) without modifying it, and mixed-framework development is supported. [supported list](https://shimo.im/sheets/kJGCCTxXvqj99RGV/F5m5Z)
+🇨🇳 TensorLayerX is a cross-platform development framework that supports TensorFlow, PyTorch, MindSpore, PaddlePaddle, OneFlow and Jittor; users can run the same code on all kinds of operating systems and AI hardware (such as Nvidia-GPU and Huawei-Ascend) without modifying it, and mixed-framework development is supported. This project is maintained by researchers from Peking University, Peng Cheng Lab, the University of Edinburgh, Imperial College London, Tsinghua, Princeton and Stanford, among other institutions.
+[supported list](https://shimo.im/sheets/kJGCCTxXvqj99RGV/F5m5Z)
```
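
To illustrate the backend switch described in the README text: a minimal sketch, assuming `TL_BACKEND` is read once at import time (the tutorial later in this commit sets it the same way).

```python
import os

# The backend must be chosen before tensorlayerx is first imported.
os.environ['TL_BACKEND'] = 'tensorflow'  # or 'torch', 'mindspore', 'paddle'

import tensorlayerx as tlx

# The same tlx code now runs on the selected backend.
x = tlx.convert_to_tensor([1.0, 2.0, 3.0])
print(tlx.ops.is_tensor(x))  # True
```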

docs/modules/nn.rst

Lines changed: 8 additions & 3 deletions
```diff
@@ -13,8 +13,9 @@ Layer list
 .. autosummary::
 
    Module
-   ModuleList
    Sequential
+   ModuleList
+   ModuleDict
 
    Input
 
@@ -125,13 +126,17 @@ Module
 ^^^^^^^^^^^^^^^^
 .. autoclass:: Module
 
+Sequential
+^^^^^^^^^^^^^^^^
+.. autoclass:: Sequential
+
 ModuleList
 ^^^^^^^^^^^^^^^^
 .. autoclass:: ModuleList
 
-Sequential
+ModuleDict
 ^^^^^^^^^^^^^^^^
-.. autoclass:: Sequential
+.. autoclass:: ModuleDict
 
 .. -----------------------------------------------------------
 .. Input Layer
```
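
For orientation, a sketch of the newly documented `Sequential` container, assuming the MindSpore-style API where it accepts a list of layers (the `ModuleList` and `ModuleDict` containers are exercised by the tutorial diff below):

```python
import os
os.environ['TL_BACKEND'] = 'tensorflow'

import tensorlayerx as tlx
from tensorlayerx.nn import Sequential, Linear

# Sequential chains layers in order; assumed to accept a list of layers.
model = Sequential([
    Linear(out_features=800, act=tlx.ReLU, in_features=784),
    Linear(out_features=10, in_features=800),
])
print(model(tlx.nn.Input((4, 784))))
```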

examples/basic_tutorials/tutorial_ModuleList.py renamed to examples/basic_tutorials/tutorial_ModuleContainer.py

Lines changed: 26 additions & 4 deletions
```diff
@@ -2,13 +2,17 @@
 # -*- coding: utf-8 -*-
 
 import os
-# os.environ['TL_BACKEND'] = 'tensorflow'
+os.environ['TL_BACKEND'] = 'tensorflow'
 # os.environ['TL_BACKEND'] = 'mindspore'
-os.environ['TL_BACKEND'] = 'paddle'
-
-from tensorlayerx.nn import Module, ModuleList, Linear
+# os.environ['TL_BACKEND'] = 'paddle'
+# os.environ['TL_BACKEND'] = 'torch'
+import numpy as np
+from tensorlayerx.nn import Module, ModuleList, Linear, ModuleDict
 import tensorlayerx as tlx
 
+
+####################### Holds submodules in a list ########################################
+
 d1 = Linear(out_features=800, act=tlx.ReLU, in_features=784, name='linear1')
 d2 = Linear(out_features=800, act=tlx.ReLU, in_features=800, name='linear2')
 d3 = Linear(out_features=10, act=tlx.ReLU, in_features=800, name='linear3')
@@ -42,3 +46,21 @@ def forward(self, inputs):
 print(net.trainable_weights)
 print(net)
 print(net(tlx.nn.Input((10, 784))))
+
+####################### Holds submodules in a Dict ########################################
+class MyModule(Module):
+
+    def __init__(self):
+        super(MyModule, self).__init__()
+        self.dict = ModuleDict({
+            'linear1': Linear(out_features=800, act=tlx.ReLU, in_features=784, name='linear1'),
+            'linear2': Linear(out_features=800, act=tlx.ReLU, in_features=800, name='linear2')
+        })
+    def forward(self, x, linear):
+        x = self.dict[linear](x)
+        return x
+
+x = tlx.convert_to_tensor(np.ones(shape=(1,784)), dtype=tlx.float32)
+net = MyModule()
+x = net(x, 'linear1')
+print(x)
```
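
The diff elides the unchanged middle of the tutorial, where the `ModuleList` is consumed inside `forward`. A hypothetical sketch of that pattern (class and attribute names are illustrative, not from the commit), assuming `ModuleList` is iterable like its PyTorch counterpart:

```python
import os
os.environ['TL_BACKEND'] = 'tensorflow'

import tensorlayerx as tlx
from tensorlayerx.nn import Module, ModuleList, Linear

class StackedMLP(Module):  # hypothetical name, not from the commit
    def __init__(self):
        super(StackedMLP, self).__init__()
        # ModuleList registers each layer's weights, unlike a plain Python list.
        self.blocks = ModuleList([
            Linear(out_features=800, act=tlx.ReLU, in_features=784),
            Linear(out_features=10, in_features=800),
        ])

    def forward(self, x):
        for block in self.blocks:  # assumed iterable/indexable
            x = block(x)
        return x

net = StackedMLP()
print(net(tlx.nn.Input((10, 784))))
```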

tensorlayerx/__init__.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -21,8 +21,6 @@
 from tensorlayerx import losses
 from tensorlayerx import decorators
 from tensorlayerx import files
-from .utils import lazy_imports
-from . import utils
 from tensorlayerx import logging
 from tensorlayerx import model
 from tensorlayerx import optimizers
```

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -1054,8 +1054,7 @@ def transpose(a, perm=None, conjugate=False):
     A transposed Tensor.
     """
     # TODO conjugate
-    trans_obj = P.Transpose()
-    outputs = trans_obj(a, perm)
+    outputs = msnp.transpose(a, perm)
     print(outputs)
 
@@ -1816,4 +1815,9 @@ def set_seed(seed):
 
 def is_tensor(x):
 
-    return isinstance(x, ms.Tensor)
+    return isinstance(x, ms.Tensor)
+
+def tensor_scatter_nd_update(tensor, indices, updates):
+
+    op = ms.ops.TensorScatterUpdate()
+    return op(tensor, indices, updates)
```
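
A small check of the semantics the new MindSpore op should provide — a sketch using the public `ms.ops.TensorScatterUpdate` primitive with illustrative values:

```python
import numpy as np
import mindspore as ms

tensor = ms.Tensor(np.zeros(4, np.float32))
indices = ms.Tensor(np.array([[1], [3]], np.int32))  # each row indexes dim 0
updates = ms.Tensor(np.array([5.0, 7.0], np.float32))

op = ms.ops.TensorScatterUpdate()
print(op(tensor, indices, updates))  # expected: [0. 5. 0. 7.]
```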

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -824,7 +824,15 @@ def transpose(a, perm=None, conjugate=False):
     -------
     A transposed Tensor.
     """
-
+    if perm == None:
+        if len(a.shape) <= 2:
+            return pd.t(a)
+        if len(a.shape) == 3:
+            perm = [2, 1, 0]
+        if len(a.shape) == 4:
+            perm = [3, 2, 1, 0]
+        if len(a.shape) == 5:
+            perm = [4, 3, 2, 1, 0]
     return pd.transpose(a, perm)
 
@@ -1503,4 +1511,12 @@ def set_seed(seed):
 
 def is_tensor(x):
 
-    return pd.is_tensor(x)
+    return pd.is_tensor(x)
+
+
+def tensor_scatter_nd_update(tensor, indices, updates):
+    a = pd.scatter_nd(indices, pd.ones_like(updates), tensor.shape)
+    a = pd.multiply(tensor, -a)
+    tensor = tensor + a
+    x = pd.scatter_nd_add(tensor, indices, updates)
+    return x
```
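
The Paddle backend has no single scatter-update op, so the new function zeroes the target positions and then scatter-adds the new values. A plain-numpy walkthrough of that mask-then-add arithmetic (illustrative values, not TensorLayerX code):

```python
import numpy as np

tensor  = np.array([1., 2., 3., 4.])
indices = np.array([[1], [3]])      # positions to overwrite
updates = np.array([5., 7.])

# Step 1: scatter ones into an empty tensor -> mask of updated positions.
mask = np.zeros_like(tensor)
mask[indices[:, 0]] = 1.            # mirrors pd.scatter_nd(indices, ones, shape)

# Step 2: tensor + tensor * (-mask) zeroes exactly those positions.
tensor = tensor + tensor * -mask    # -> [1. 0. 3. 0.]

# Step 3: scatter-add the updates into the zeroed slots.
tensor[indices[:, 0]] += updates    # mirrors pd.scatter_nd_add
print(tensor)                       # [1. 5. 3. 7.]
```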

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -3597,4 +3597,21 @@ def is_tensor(x):
     >>> tlx.ops.is_tensor(a)
     """
 
-    return tf.is_tensor(x)
+    return tf.is_tensor(x)
+
+
+def tensor_scatter_nd_update(tensor, indices, updates):
+    """Scatter `updates` into a copy of `tensor` at the positions given by `indices`.
+
+    Parameters
+    ----------
+    tensor : the tensor to update.
+    indices : integer tensor of shape (N, K); each row indexes the first K dimensions of `tensor`.
+    updates : tensor of N replacement values (or slices).
+
+    Returns
+    -------
+    A tensor with the same shape as `tensor`, with the indexed positions replaced.
+    """
+
+    return tf.tensor_scatter_nd_update(tensor, indices, updates)
```
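
The TensorFlow op is the reference semantics that the other backends emulate; for example:

```python
import tensorflow as tf

tensor = tf.zeros([4])
indices = tf.constant([[1], [3]])
updates = tf.constant([5.0, 7.0])

print(tf.tensor_scatter_nd_update(tensor, indices, updates))
# tf.Tensor([0. 5. 0. 7.], shape=(4,), dtype=float32)
```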

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -1630,4 +1630,12 @@ def set_seed(seed):
 
 def is_tensor(x):
 
-    return isinstance(x, torch.Tensor)
+    return isinstance(x, torch.Tensor)
+
+def tensor_scatter_nd_update(tensor, indices, updates):
+
+    indices = torch.flatten(indices)
+    tensor[indices] = updates
+    return tensor
+
+
```
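
Note that the PyTorch version flattens `indices` and assigns along the first dimension in place, so it matches the TensorFlow semantics only for `(N, 1)` indices and mutates its input. A hedged sketch of a more general out-of-place variant, should one ever be needed (hypothetical, not part of this commit):

```python
import torch

def tensor_scatter_nd_update_general(tensor, indices, updates):
    # Hypothetical generalization: indices is an (N, K) integer tensor
    # whose rows address the first K dimensions of `tensor`.
    out = tensor.clone()                       # out-of-place, like the TF op
    out[tuple(indices.long().t())] = updates   # advanced indexing over K dims
    return out

t = torch.zeros(4)
print(tensor_scatter_nd_update_general(
    t, torch.tensor([[1], [3]]), torch.tensor([5.0, 7.0])))
# tensor([0., 5., 0., 7.])
```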

tensorlayerx/dataflow/utils.py

Lines changed: 13 additions & 15 deletions
```diff
@@ -8,14 +8,12 @@
 import numbers
 import itertools
 import multiprocessing
-import threading
 import queue
 from collections import namedtuple
 from dataclasses import dataclass
 import sys
 import traceback
 
-
 def default_convert(data):
     data_type = type(data)
     if isinstance(data, np.ndarray):
@@ -264,7 +262,6 @@ def __init__(self, loader):
         self._persistent_workers = loader.persistent_workers
         self._time_out = loader.time_out
         self._sampler_iter = iter(self._index_sampler)
-        # self._pin_memory = loader.pin_memory
         self._num_yielded = 0
 
     def __iter__(self):
@@ -321,7 +318,6 @@ def __init__(self, loader):
         self._worker_result_queue = multiprocessing.Queue()
         self._worker_done_event = multiprocessing.Event()
         self._worker_pids_set = False
-        self._shutdown = False
 
         self._index_queues = []
         self._workers = []
@@ -357,6 +353,7 @@ def _reset(self, loader, first_iter=False):
         while resume_iteration_cnt > 0:
             return_idx, return_data = self._get_data()
             if isinstance(return_idx, _ResumeIteration):
+                assert return_data is None
                 resume_iteration_cnt -= 1
         for _ in range(self._prefetch_factor * self._num_workers):
 
@@ -469,26 +466,19 @@ def _shutdown_workers(self):
         if not self._shutdown:
             self._shutdown = True
             try:
-                if hasattr(self, '_pin_memory_thread'):
-                    self._pin_memory_thread_done_event.set()
-                    self._worker_result_queue.put((None, None))
-                    self._pin_memory_thread.join()
-                    self._worker_result_queue.cancel_join_thread()
-                    self._worker_result_queue.close()
-
                 self._worker_done_event.set()
                 for worker_id in range(len(self._workers)):
                     if self._persistent_workers or self._workers_status[worker_id]:
                         self._mark_worker_as_unavailable(worker_id, shutdown=True)
                 for w in self._workers:
                     w.join(timeout=5.0)
-                    if w.is_alive():
-                        w.terminate()
                 for q in self._index_queues:
                     q.cancel_join_thread()
                     q.close()
             finally:
-                pass
+                for w in self._workers:
+                    if w.is_alive():
+                        w.terminate()
 
     def __del__(self):
         self._shutdown_workers()
@@ -565,7 +555,15 @@ def _worker_loop(
             try:
                 data = fetcher.fetch(index)
             except Exception as e:
-                data = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id))
+                if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iter:
+                    data = _IterableDatasetStopIteration(worker_id)
+                    iteration_end = True
+                else:
+                    # It is important that we don't store exc_info in a variable.
+                    # `ExceptionWrapper` does the correct thing.
+                    # See NOTE [ Python Traceback Reference Cycle Problem ]
+                    data = ExceptionWrapper(
+                        where="in DataLoader worker process {}".format(worker_id))
             data_queue.put((idx, data))
             del data, idx, index, r
         except KeyboardInterrupt:
```
tensorlayerx/files/dataset_loaders/cyclegan_dataset.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -6,7 +6,7 @@
 import numpy as np
 
 from tensorlayerx import logging
-from tensorlayerx.utils import visualize
+from tensorlayerx.vision import load_images
 from tensorlayerx.files.utils import (del_file, folder_exists, load_file_list, maybe_download_and_extract)
 
 __all__ = ['load_cyclegan_dataset']
@@ -36,8 +36,7 @@ def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data'):
         del_file(os.path.join(path, filename + '.zip'))
 
     def load_image_from_folder(path):
-        path_imgs = load_file_list(path=path, regx='\\.jpg', printable=False)
-        return visualize.read_images(path_imgs, path=path, n_threads=10, printable=False)
+        return load_images(path=path, n_threads=10)
 
     im_train_A = load_image_from_folder(os.path.join(path, filename, "trainA"))
     im_train_B = load_image_from_folder(os.path.join(path, filename, "trainB"))
```
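
A usage sketch for the updated loader, assuming the function returns the four train/test splits suggested by the folder layout in the diff:

```python
import tensorlayerx as tlx

# Downloads and unpacks the dataset on first use, then loads the images
# via tensorlayerx.vision.load_images as patched above.
im_train_A, im_train_B, im_test_A, im_test_B = tlx.files.load_cyclegan_dataset(
    filename='summer2winter_yosemite', path='data')

print(len(im_train_A), len(im_train_B))
```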
