Skip to content

Commit 1ae2442

Browse files
committed
Merge branch 'main' of github.com:tensorlayer/TensorLayerX into main
2 parents 78147ab + 796543c commit 1ae2442

File tree

7 files changed

+62
-51
lines changed

7 files changed

+62
-51
lines changed

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,49 @@
22
# -*- coding: utf-8 -*-
33

44
from __future__ import absolute_import, division, print_function
5+
6+
import paddle
57
import paddle as pd
68
import paddle.nn as nn
79
import numpy as np
810
import paddle.nn.functional as F
911
from .paddle_nn import nchw_to_nhwc, nhwc_to_nchw, preprocess_2d_format, preprocess_1d_format, preprocess_3d_format
1012
import random
1113

12-
_dtypeDict = [
13-
"float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "bool",
14-
"complex64", "complex128"
15-
]
14+
15+
_dtypeDict = {
16+
'DType': paddle.dtype,
17+
'float16': paddle.float16,
18+
'float32': paddle.float32,
19+
'float64': paddle.float64,
20+
'int8': paddle.int8,
21+
'int16': paddle.int16,
22+
'int32': paddle.int32,
23+
'int64': paddle.int64,
24+
'uint8': paddle.uint8,
25+
'uint16': None,
26+
'uint32': None,
27+
'uint64': None,
28+
'bool': paddle.bool,
29+
'complex64': paddle.complex64,
30+
'complex128': paddle.complex128
31+
}
1632
# TODO NotImplemented
17-
DType = None
18-
float16 = "float16"
19-
float32 = "float32"
20-
float64 = "float64"
21-
int8 = "int8"
22-
int16 = "int16"
23-
int32 = "int32"
24-
int64 = "int64"
25-
uint8 = "uint8"
26-
uint16 = "uint16"
27-
uint32 = "uint32"
28-
uint64 = "uint64"
29-
bool = "bool"
30-
complex64 = "complex64"
31-
complex128 = "complex128"
33+
DType = paddle.dtype
34+
float16 = paddle.float16
35+
float32 = paddle.float32
36+
float64 = paddle.float64
37+
int8 = paddle.int8
38+
int16 = paddle.int16
39+
int32 = paddle.int32
40+
int64 = paddle.int64
41+
uint8 = paddle.uint8
42+
uint16 = None
43+
uint32 = None
44+
uint64 = None
45+
bool = paddle.bool
46+
complex64 = paddle.complex64
47+
complex128 = paddle.complex128
3248

3349

3450
def _getter(init_fn, **kwargs):

tensorlayerx/backend/ops/tensorflow_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2765,12 +2765,12 @@ def __init__(self, data_format):
27652765
def __call__(self, input, weight):
27662766

27672767
pos = tf.nn.relu(input)
2768-
neg = -weight * tf.nn.relu(input)
2768+
neg = -tf.nn.sigmoid(weight) * tf.nn.relu(-input)
27692769
return pos + neg
27702770

27712771

27722772
def prelu(input, weight, data_format):
27732773

27742774
pos = tf.nn.relu(input)
2775-
neg = -weight * tf.nn.relu(input)
2775+
neg = -tf.nn.sigmoid(weight) * tf.nn.relu(-input)
27762776
return pos + neg

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,10 @@ class ModuleList(Module):
478478
>>> layer_list.append(d3)
479479
"""
480480

481-
def __init__(self, args):
482-
Module.__init__(self)
483-
self.extend(args)
481+
def __init__(self, modules = None):
482+
super(ModuleList, self).__init__()
483+
if modules is not None:
484+
self.extend(modules)
484485

485486
def __getitem__(self, index):
486487
if isinstance(index, slice):
@@ -555,9 +556,10 @@ def forward(self, *inputs):
555556

556557
class ModuleDict(Module):
557558

558-
def __init__(self, modules):
559+
def __init__(self, modules = None):
559560
super(ModuleDict, self).__init__()
560-
self.update(modules)
561+
if modules is not None:
562+
self.update(modules)
561563

562564
def __getitem__(self, key):
563565

tensorlayerx/nn/core/core_paddle.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,10 @@ def forward(self, input_data):
416416

417417
class ModuleList(Module):
418418

419-
def __init__(self, args):
419+
def __init__(self, modules = None):
420420
super(ModuleList, self).__init__()
421-
if args is not None:
422-
self.extend(args)
421+
if modules is not None:
422+
self.extend(modules)
423423

424424
def __getitem__(self, index):
425425
if isinstance(index, slice):
@@ -501,9 +501,10 @@ def forward(self, *inputs):
501501

502502
class ModuleDict(Module):
503503

504-
def __init__(self, modules):
504+
def __init__(self, modules = None):
505505
super(ModuleDict, self).__init__()
506-
self.update(modules)
506+
if modules is not None:
507+
self.update(modules)
507508

508509
def __getitem__(self, key):
509510

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -760,9 +760,10 @@ class ModuleList(Module):
760760
>>> layer_list.append(d3)
761761
"""
762762

763-
def __init__(self, args):
763+
def __init__(self, modules = None):
764764
super(ModuleList, self).__init__()
765-
self.extend(args)
765+
if modules is not None:
766+
self.extend(modules)
766767

767768
def __getitem__(self, index):
768769
if isinstance(index, slice):
@@ -891,9 +892,10 @@ class ModuleDict(Module):
891892
892893
"""
893894

894-
def __init__(self, modules):
895+
def __init__(self, modules = None):
895896
super(ModuleDict, self).__init__()
896-
self.update(modules)
897+
if modules is not None:
898+
self.update(modules)
897899

898900
def __getitem__(self, key):
899901

tensorlayerx/nn/core/core_torch.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,10 @@ class ModuleList(Module):
376376
>>> layer_list.append(d3)
377377
"""
378378

379-
def __init__(self, args):
379+
def __init__(self, modules = None):
380380
super(ModuleList, self).__init__()
381-
self.extend(args)
381+
if modules is not None:
382+
self.extend(modules)
382383

383384
def __getitem__(self, index):
384385
if isinstance(index, slice):
@@ -460,9 +461,10 @@ def forward(self, *inputs):
460461

461462
class ModuleDict(Module):
462463

463-
def __init__(self, modules):
464+
def __init__(self, modules = None):
464465
super(ModuleDict, self).__init__()
465-
self.update(modules)
466+
if modules is not None:
467+
self.update(modules)
466468

467469
def __getitem__(self, key):
468470

tensorlayerx/nn/layers/convolution/simplified_conv.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,9 @@ def __repr__(self):
109109

110110
def build(self, inputs_shape):
111111
if self.data_format == 'channels_last':
112-
self.data_format = 'NWC'
113112
if self.in_channels is None:
114113
self.in_channels = inputs_shape[-1]
115114
elif self.data_format == 'channels_first':
116-
self.data_format = 'NCW'
117115
if self.in_channels is None:
118116
self.in_channels = inputs_shape[1]
119117
else:
@@ -253,13 +251,11 @@ def __repr__(self):
253251

254252
def build(self, inputs_shape):
255253
if self.data_format == 'channels_last':
256-
self.data_format = 'NHWC'
257254
if self.in_channels is None:
258255
self.in_channels = inputs_shape[-1]
259256
self._strides = [1, self._strides[0], self._strides[1], 1]
260257
self._dilation_rate = [1, self._dilation_rate[0], self._dilation_rate[1], 1]
261258
elif self.data_format == 'channels_first':
262-
self.data_format = 'NCHW'
263259
if self.in_channels is None:
264260
self.in_channels = inputs_shape[1]
265261
self._strides = [1, 1, self._strides[0], self._strides[1]]
@@ -400,13 +396,11 @@ def __repr__(self):
400396

401397
def build(self, inputs_shape):
402398
if self.data_format == 'channels_last':
403-
self.data_format = 'NDHWC'
404399
if self.in_channels is None:
405400
self.in_channels = inputs_shape[-1]
406401
self._strides = [1, self._strides[0], self._strides[1], self._strides[2], 1]
407402
self._dilation_rate = [1, self.dilation[0], self.dilation[1], self.dilation[2], 1]
408403
elif self.data_format == 'channels_first':
409-
self.data_format = 'NCDHW'
410404
if self.in_channels is None:
411405
self.in_channels = inputs_shape[1]
412406
self._strides = [1, 1, self._strides[0], self._strides[1], self._strides[2]]
@@ -548,11 +542,9 @@ def __repr__(self):
548542

549543
def build(self, inputs_shape):
550544
if self.data_format == 'channels_last':
551-
self.data_format = 'NWC'
552545
if self.in_channels is None:
553546
self.in_channels = inputs_shape[-1]
554547
elif self.data_format == 'channels_first':
555-
self.data_format = 'NCW'
556548
if self.in_channels is None:
557549
self.in_channels = inputs_shape[1]
558550
else:
@@ -697,11 +689,9 @@ def __repr__(self):
697689

698690
def build(self, inputs_shape):
699691
if self.data_format == 'channels_last':
700-
self.data_format = 'NHWC'
701692
if self.in_channels is None:
702693
self.in_channels = inputs_shape[-1]
703694
elif self.data_format == 'channels_first':
704-
self.data_format = 'NCHW'
705695
if self.in_channels is None:
706696
self.in_channels = inputs_shape[1]
707697
else:
@@ -840,11 +830,9 @@ def __repr__(self):
840830

841831
def build(self, inputs_shape):
842832
if self.data_format == 'channels_last':
843-
self.data_format = 'NDHWC'
844833
if self.in_channels is None:
845834
self.in_channels = inputs_shape[-1]
846835
elif self.data_format == 'channels_first':
847-
self.data_format = 'NCDHW'
848836
if self.in_channels is None:
849837
self.in_channels = inputs_shape[1]
850838
else:

0 commit comments

Comments
 (0)