Skip to content

Commit 5b816ff

Browse files
committed
Added The argument can be an int or a tuple
1 parent 41340b3 commit 5b816ff

23 files changed

+233
-164
lines changed

tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -529,13 +529,12 @@ class Conv2D(Cell):
529529
def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
530530
super(Conv2D, self).__init__()
531531
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
532-
533532
if self.data_format is 'NHWC':
534-
self.ms_stride = strides[1]
535-
self.ms_dilation = dilations[1]
533+
self._stride = (strides[1], strides[2])
534+
self._dilation = (dilations[1], dilations[2])
536535
elif self.data_format is 'NCHW':
537-
self.ms_stride = strides[2]
538-
self.ms_dilation = dilations[2]
536+
self._stride = (strides[2], strides[3])
537+
self._dilation = (dilations[2], dilations[3])
539538

540539
self.conv2d = P.Conv2D(
541540
out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
@@ -582,12 +581,12 @@ def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_ch
582581
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
583582

584583
if self.data_format is 'NDHWC':
585-
self.ms_stride = strides[1]
586-
self.ms_dilation = dilations[1]
584+
self.ms_stride = (strides[1], strides[2], strides[3])
585+
self.ms_dilation = (dilations[1], dilations[2], dilations[3])
587586
raise NotImplementedError("The optional value for data format. Currently only support “NCDHW”.")
588587
elif self.data_format is 'NCDHW':
589-
self.ms_stride = strides[2]
590-
self.ms_dilation = dilations[2]
588+
self.ms_stride = (strides[2], strides[3], strides[4])
589+
self.ms_dilation = (dilations[2], dilations[3], dilations[4])
591590

592591
self.conv3d = P.Conv3D(
593592
out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,61 @@
1111
from .paddle_nn import nchw_to_nhwc, nhwc_to_nchw, preprocess_2d_format, preprocess_1d_format, preprocess_3d_format
1212
import random
1313

14-
15-
_dtypeDict = {
16-
'DType': paddle.dtype,
17-
'float16': paddle.float16,
18-
'float32': paddle.float32,
19-
'float64': paddle.float64,
20-
'int8': paddle.int8,
21-
'int16': paddle.int16,
22-
'int32': paddle.int32,
23-
'int64': paddle.int64,
24-
'uint8': paddle.uint8,
25-
'uint16': None,
26-
'uint32': None,
27-
'uint64': None,
28-
'bool': paddle.bool,
29-
'complex64': paddle.complex64,
30-
'complex128': paddle.complex128
31-
}
32-
# TODO NotImplemented
33-
DType = paddle.dtype
34-
float16 = paddle.float16
35-
float32 = paddle.float32
36-
float64 = paddle.float64
37-
int8 = paddle.int8
38-
int16 = paddle.int16
39-
int32 = paddle.int32
40-
int64 = paddle.int64
41-
uint8 = paddle.uint8
42-
uint16 = None
43-
uint32 = None
44-
uint64 = None
45-
bool = paddle.bool
46-
complex64 = paddle.complex64
47-
complex128 = paddle.complex128
14+
if paddle.__version__ < '2.2.2':
15+
_dtypeDict = [
16+
"float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "bool",
17+
"complex64", "complex128"
18+
]
19+
# TODO NotImplemented
20+
DType = None
21+
float16 = "float16"
22+
float32 = "float32"
23+
float64 = "float64"
24+
int8 = "int8"
25+
int16 = "int16"
26+
int32 = "int32"
27+
int64 = "int64"
28+
uint8 = "uint8"
29+
uint16 = "uint16"
30+
uint32 = "uint32"
31+
uint64 = "uint64"
32+
bool = "bool"
33+
complex64 = "complex64"
34+
complex128 = "complex128"
35+
else:
36+
_dtypeDict = {
37+
'DType': paddle.dtype ,
38+
'float16': paddle.float16,
39+
'float32': paddle.float32,
40+
'float64': paddle.float64,
41+
'int8': paddle.int8,
42+
'int16': paddle.int16,
43+
'int32': paddle.int32,
44+
'int64': paddle.int64,
45+
'uint8': paddle.uint8,
46+
'uint16': None,
47+
'uint32': None,
48+
'uint64': None,
49+
'bool': paddle.bool,
50+
'complex64': paddle.complex64,
51+
'complex128': paddle.complex128
52+
}
53+
# TODO NotImplemented
54+
DType = paddle.dtype
55+
float16 = paddle.float16
56+
float32 = paddle.float32
57+
float64 = paddle.float64
58+
int8 = paddle.int8
59+
int16 = paddle.int16
60+
int32 = paddle.int32
61+
int64 = paddle.int64
62+
uint8 = paddle.uint8
63+
uint16 = None
64+
uint32 = None
65+
uint64 = None
66+
bool = paddle.bool
67+
complex64 = paddle.complex64
68+
complex128 = paddle.complex128
4869

4970

5071
def _getter(init_fn, **kwargs):

tensorlayerx/backend/ops/paddle_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,10 +548,10 @@ class Conv3D(object):
548548

549549
def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
550550
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
551-
if data_format is 'NDHWC':
551+
if self.data_format is 'NDHWC':
552552
self._strides = (strides[1], strides[2], strides[3])
553553
self._dilations = (dilations[1], dilations[2], dilations[3])
554-
elif data_format is 'NCDHW':
554+
elif self.data_format is 'NCDHW':
555555
self._strides = (strides[2], strides[3], strides[4])
556556
self._dilations = (dilations[2], dilations[3], dilations[4])
557557

tensorlayerx/backend/ops/tensorflow_nn.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def preprocess_1d_format(data_format, padding):
5454
-------
5555
str "NWC" or "NCW" and "SAME" or "VALID"
5656
"""
57-
if data_format in ["channels_last", "NWC"]:
57+
if data_format in ["channels_last", "NWC", 'NLC']:
5858
data_format = "NWC"
59-
elif data_format in ["channels_first", "NCW"]:
59+
elif data_format in ["channels_first", "NCW", 'NCL']:
6060
data_format = "NCW"
6161
elif data_format == None:
6262
data_format = None
@@ -399,9 +399,11 @@ class BiasAdd(object):
399399
A Tensor with the same type as value.
400400
"""
401401

402-
def __init__(self, data_format=None):
403-
self.data_format, _ = preprocess_2d_format(data_format, None)
404-
402+
def __init__(self, data_format='channels_last'):
403+
if data_format in ['channels_first', 'NCL', 'NCW', 'NCHW', 'NCDHW']:
404+
self.data_format = "NCHW"
405+
elif data_format in ['channels_last', 'NLC', 'NWC', 'NHWC', 'NDHWC']:
406+
self.data_format = "NHWC"
405407
def __call__(self, x, bias):
406408
return tf.nn.bias_add(x, bias, data_format=self.data_format)
407409

tensorlayerx/backend/ops/torch_nn.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -573,9 +573,13 @@ def same_padding(input, weight, strides, dilations):
573573
class Conv2D(object):
574574

575575
def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None, groups=1):
576-
self.strides = (strides[1], strides[2])
577-
self.dilations = (dilations[1], dilations[2])
578576
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
577+
if self.data_format is 'NHWC':
578+
self.strides = (strides[1], strides[2])
579+
self.dilations = (dilations[1], dilations[2])
580+
elif self.data_format is 'NCHW':
581+
self.strides = (strides[2], strides[3])
582+
self.dilations = (dilations[2], dilations[3])
579583
self.groups = groups
580584

581585
def __call__(self, input, filters):
@@ -644,13 +648,14 @@ def conv2d(input, filters, strides, padding, data_format='NHWC', dilations=None)
644648
class Conv3D(object):
645649

646650
def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
647-
if data_format is 'NDHWC':
651+
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
652+
if self.data_format is 'NDHWC':
648653
self._strides = (strides[1], strides[2], strides[3])
649654
self._dilations = (dilations[1], dilations[2], dilations[3])
650-
elif data_format is 'NCDHW':
655+
elif self.data_format is 'NCDHW':
651656
self._strides = (strides[2], strides[3], strides[4])
652657
self._dilations = (dilations[2], dilations[3], dilations[4])
653-
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
658+
654659

655660
def __call__(self, input, filters):
656661
if self.data_format == 'NDHWC':

tensorlayerx/nn/core/common.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,22 @@
3939
}
4040

4141

42+
def check_parameter(parameter, dim='2d'):
43+
if dim == '2d':
44+
if isinstance(parameter, int):
45+
out = (parameter, parameter)
46+
else:
47+
out = parameter
48+
elif dim == '3d':
49+
if isinstance(parameter, int):
50+
out = (parameter, parameter, parameter)
51+
else:
52+
out = parameter
53+
else:
54+
raise ("dim must be 2d or 3d.")
55+
return out
56+
57+
4258
def str2init(initializer):
4359
if isinstance(initializer, str):
4460
if initializer not in _initializers_dict.keys():
@@ -507,6 +523,19 @@ def construct_graph(inputs, outputs):
507523
return node_by_depth, all_layers
508524

509525

526+
def select_attrs(obj):
527+
attrs_dict = obj.__dict__
528+
attrs = {}
529+
_select_key = ['kernel_size', 'stride', 'act', 'padding', 'data_format', 'concat_dim']
530+
for k in _select_key:
531+
if k in attrs_dict:
532+
if k == 'act':
533+
attrs[k] = attrs_dict[k].__class__.__name__
534+
else:
535+
attrs[k] = attrs_dict[k]
536+
return attrs
537+
538+
510539
class ModuleNode(object):
511540
"""
512541
The class :class:`ModuleNode` class represents a conceptional node for a layer.
@@ -537,7 +566,7 @@ class ModuleNode(object):
537566
(1) Forwarding through the layer. (2) Update its input/output tensors.
538567
"""
539568

540-
def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tensor_idxes):
569+
def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tensor_idxes, attr):
541570
self.layer = layer
542571
self.node_index = node_index
543572
self.in_nodes = in_nodes
@@ -547,6 +576,7 @@ def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tens
547576
self.node_name = layer.name + "_node_{}".format(node_index)
548577

549578
self.in_tensors_idxes = in_tensor_idxes
579+
self.attr = attr
550580
self.visited = False
551581

552582
def __call__(self, inputs, **kwargs):

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33

4-
from .common import str2act, str2init, random_normal, tolist, construct_graph, ModuleNode
4+
from .common import check_parameter, str2act, str2init, random_normal, tolist, construct_graph, ModuleNode, select_attrs
55
from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
66
from mindspore.nn import Cell
77
import tensorlayerx as tlx
@@ -105,7 +105,7 @@ def _get_weights(self, var_name, shape, init=random_normal(), trainable=True, tr
105105
if len(shape) == 3:
106106
shape = shape[::-1]
107107
if len(shape) == 4:
108-
if not transposed and self.data_format == 'NHWC':
108+
if not transposed and self.data_format in ['NHWC', 'channels_last']:
109109
shape = (shape[3], shape[0], shape[1], shape[2])
110110
else:
111111
shape = (shape[3], shape[2], shape[0], shape[1])
@@ -265,6 +265,9 @@ def all_weights(self):
265265
def str_to_init(self, initializer):
266266
return str2init(initializer)
267267

268+
def check_param(self, param, dim='2d'):
269+
return check_parameter(param, dim)
270+
268271
def insert_child_to_layer(self, child_name, child):
269272
"""
270273
Adds a child layer to the current layer.
@@ -333,7 +336,7 @@ def _add_node(self, input_tensors, output_tensors):
333336
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
334337
node_index = len(_global_layer_node)
335338

336-
new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
339+
new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(self))
337340
_global_layer_node.append(new_node)
338341
for idx, tensor in enumerate(outputs_list):
339342
tensor._info = (new_node, idx)

tensorlayerx/nn/core/core_paddle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: utf-8 -*-
33

44
import copy, six
5-
from .common import str2act, str2init, tolist, construct_graph, ModuleNode
5+
from .common import check_parameter, str2act, str2init, tolist, construct_graph, ModuleNode
66
from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
77
from paddle.fluid import framework
88
from paddle.fluid.dygraph import Layer
@@ -288,6 +288,9 @@ def load_standard_weights(self, file_path, skip=False, reshape=False, format='np
288288
def str_to_init(self, initializer):
289289
return str2init(initializer)
290290

291+
def check_param(self, param, dim='2d'):
292+
return check_parameter(param, dim)
293+
291294
def insert_child_to_layer(self, child_name, child):
292295
"""
293296
Adds a child layer to the current layer.

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33

4-
from .common import str2act, str2init, tolist, construct_graph, ModuleNode
4+
from .common import check_parameter, str2act, str2init, tolist, construct_graph, ModuleNode, select_attrs
55
from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
66
from collections import OrderedDict, abc as container_abcs
77
from collections import OrderedDict
@@ -582,6 +582,9 @@ def init_build(self, *inputs, **kwargs):
582582
def str_to_init(self, initializer):
583583
return str2init(initializer)
584584

585+
def check_param(self, param, dim='2d'):
586+
return check_parameter(param, dim)
587+
585588
def build_graph(self, *inputs, **kwargs):
586589
# Add nodes only when the composition is needed.
587590
layers = self.layers_and_names(name_prefix='')
@@ -621,7 +624,7 @@ def _add_node(self, input_tensors, output_tensors):
621624
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
622625
node_index = len(_global_layer_node)
623626

624-
new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
627+
new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(self))
625628
_global_layer_node.append(new_node)
626629
for idx, tensor in enumerate(outputs_list):
627630
tensor._info = (new_node, idx)

tensorlayerx/nn/core/core_torch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: utf-8 -*-
33

44
from torch.nn import Module as T_Module
5-
from .common import str2act, str2init, tolist, construct_graph, ModuleNode
5+
from .common import check_parameter, str2act, str2init, tolist, construct_graph, ModuleNode, select_attrs
66
from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
77
from torch.nn.parameter import Parameter
88
from collections import OrderedDict
@@ -173,6 +173,9 @@ def load_standard_weights(self, file_path, skip=False, reshape=False, format='np
173173
def str_to_init(self, initializer):
174174
return str2init(initializer)
175175

176+
def check_param(self, param, dim='2d'):
177+
return check_parameter(param, dim)
178+
176179
def init_build(self, *inputs, **kwargs):
177180
"""
178181
(1) This method must be called when the Layer has no input in_channels.
@@ -220,7 +223,7 @@ def _add_node(self, input_tensors, output_tensors):
220223
in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
221224
node_index = len(_global_layer_node)
222225

223-
new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
226+
new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes, select_attrs(self))
224227
_global_layer_node.append(new_node)
225228
for idx, tensor in enumerate(outputs_list):
226229
tensor._info = (new_node, idx)

0 commit comments

Comments
 (0)