Commit 8e0b2fe

Added TF collection of tuple parameters
1 parent de7d9dc commit 8e0b2fe

5 files changed: 86 additions & 25 deletions

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,9 @@
 import numpy
 from mindspore.common.api import _pynative_executor
 from collections import OrderedDict, abc as container_abcs
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
+from mindspore.nn.cell import ParameterTuple
+
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
 
 _global_layer_name_dict = {}
 _global_layer_node = []
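
On this backend the commit simply re-exports MindSpore's native ParameterTuple through tensorlayerx.nn.core rather than defining a new type. A minimal sketch of the resulting usage, assuming the stock MindSpore API (parameter names are illustrative):

    # Illustrative only: the MindSpore backend re-exports the native class.
    import mindspore as ms
    from mindspore import Parameter, Tensor
    from mindspore.nn.cell import ParameterTuple

    w = Parameter(Tensor([[1.0, 2.0]], ms.float32), name="w")
    b = Parameter(Tensor([0.0], ms.float32), name="b")
    params = ParameterTuple([w, b])      # immutable, name-checked collection
    print([p.name for p in params])     # ['w', 'b']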

tensorlayerx/nn/core/core_paddle.py

Lines changed: 3 additions & 5 deletions
@@ -14,13 +14,12 @@
 from paddle.fluid.dygraph import parallel_helper
 import paddle as pd
 from collections import OrderedDict, abc as container_abcs
-import tensorlayerx as tlx
-from queue import Queue
+from paddle.nn import ParameterList as ParameterTuple
 
 _global_layer_name_dict = {}
 _global_layer_node = []
-
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
+# TODO Need to implement ParameterTuple
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
 
 
 class Module(Layer):

@@ -744,7 +743,6 @@ def update(self, parameters):
     def __call__(self, input):
         raise RuntimeError('ParameterDict should not be called.')
 
-
 def _valid_index(layer_num, index):
     if not isinstance(index, int):
         raise TypeError("Index {} is not int type")
tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 58 additions & 1 deletion
@@ -10,7 +10,7 @@
 import tensorflow as tf
 from tensorlayerx.nn.layers.utils import (get_variable_with_initializer, random_normal)
 
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
 
 _global_layer_name_dict = {}
 _global_layer_node = []

@@ -55,6 +55,7 @@ class Module(object):
     def __init__(self, name=None, act=None, *args, **kwargs):
         self._params = OrderedDict()
         self._layers = OrderedDict()
+        self._params_tuple = OrderedDict()
         self._params_status = OrderedDict()
         self._parameter_layout_dict = {}
         self._create_time = int(time.time() * 1e9)

@@ -145,6 +146,9 @@ def __setattr__(self, name, value):
                 raise TypeError("Expected type is Module, but got Parameter.")
             self.insert_param_to_layer(name, value)
 
+        elif isinstance(value, ParameterTuple):
+            self.set_attr_for_parameter_tuple(name, value)
+
         elif isinstance(value, Module):
             if layers is None:
                 raise AttributeError("Can not assign layers before Module.__init__() call.")

@@ -293,6 +297,27 @@ def _compute_shape(tensors):
         shape_mem = tlx.get_tensor_shape(tensors)
         return shape_mem
 
+    def set_attr_for_parameter_tuple(self, name, value):
+        """Set attr for parameter in ParameterTuple."""
+        params = self.__dict__.get('_params')
+        params_tuple = self.__dict__.get('_params_tuple')
+        if params is None:
+            raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
+        exist_names = set("")
+
+        for item in value:
+            self.insert_param_to_layer(item.name, item, check_name=False)
+            if item.name in exist_names:
+                raise ValueError("The value {} , its name '{}' already exists.".
+                                 format(value, item.name))
+            exist_names.add(item.name)
+
+        if name in self.__dict__:
+            del self.__dict__[name]
+        if name in params:
+            del params[name]
+        params_tuple[name] = value
+
     def insert_param_to_layer(self, param_name, param, check_name=True):
         """
         Adds a parameter to the current layer.
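
To make the new flow concrete: assigning a ParameterTuple to a module attribute is intercepted by __setattr__, which calls set_attr_for_parameter_tuple; that method registers every member variable in self._params via insert_param_to_layer and stores the tuple itself in self._params_tuple under the attribute name, so __getattr__ (next hunk) can hand the whole tuple back. A rough, untested sketch, assuming the TensorFlow backend is active and that tensorlayerx.nn re-exports ParameterTuple after this commit (class and variable names are made up):

    # Hypothetical sketch of the intended attribute flow.
    import tensorflow as tf
    from tensorlayerx.nn import Module, ParameterTuple  # assumed re-export

    class TwoHeads(Module):
        def __init__(self):
            super(TwoHeads, self).__init__(name="two_heads")
            heads = [tf.Variable(tf.zeros([4, 2]), name="head%d" % i) for i in range(2)]
            # Routed through __setattr__ -> set_attr_for_parameter_tuple:
            # each head lands in self._params, the tuple in self._params_tuple.
            self.heads = ParameterTuple(heads)

        def forward(self, x):
            # Reading self.heads goes through __getattr__ and returns the tuple.
            return [tf.matmul(x, h) for h in self.heads]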
@@ -344,6 +369,11 @@ def __getattr__(self, name):
             params_status = self.__dict__['_params_status']
             if name in params_status:
                 return params_status[name]
+        if '_params_tuple' in self.__dict__:
+            params_tuple = self.__dict__['_params_tuple']
+            if name in params_tuple:
+                para_list = params_tuple[name]
+                return para_list
         raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
 
     def __delattr__(self, name):

@@ -1258,6 +1288,33 @@ def __call__(self, input):
         raise RuntimeError('ParameterDict should not be called.')
 
 
+class ParameterTuple(tuple):
+    """
+    ParameterTuple for storing tuple of parameters.
+    """
+    def __new__(cls, iterable):
+        data = tuple(iterable)
+        ids = set()
+        orders = {}
+        for x in data:
+            if not isinstance(x, tf.Variable):
+                raise TypeError(f"ParameterTuple input should be `Parameter` collection."
+                                f"But got a {type(iterable)}, {iterable}")
+            if id(x) not in ids:
+                ids.add(id(x))
+                if x.name not in orders.keys():
+                    orders[x.name] = [0, x]
+                else:
+                    if isinstance(orders[x.name], list):
+                        name = x.name
+                        orders[name][1].name = name + "_" + str(0)
+                        x.name = x.name + "_" + str(1)
+                        orders[name] = 1
+                    else:
+                        orders[x.name] += 1
+                        x.name = x.name + "_" + str(orders[x.name])
+        return tuple.__new__(ParameterTuple, tuple(data))
+
 def _valid_index(layer_num, index):
     if not isinstance(index, int):
         raise TypeError("Index {} is not int type")

tensorlayerx/nn/core/core_torch.py

Lines changed: 17 additions & 16 deletions
@@ -12,11 +12,12 @@
 from collections import OrderedDict, abc as container_abcs
 import warnings
 import tensorlayerx as tlx
+from torch.nn import ParameterList as ParameterTuple
 
 _global_layer_name_dict = {}
 _global_layer_node = []
 
-__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict']
+__all__ = ['Module', 'Sequential', 'ModuleList', 'ModuleDict', 'Parameter', 'ParameterList', 'ParameterDict', 'ParameterTuple']
 
 
 class Module(T_Module):

@@ -122,21 +123,21 @@ def _call_impl_tlx(self, *input, **kwargs):
 
         result = self._call_impl(*input, **kwargs)
         return result
-
-    __call__: Callable[..., Any] = _call_impl_tlx
-
-    def _named_members(self, get_members_fn, prefix='', recurse=True):
-        r"""Helper method for yielding various names + members of modules."""
-        memo = set()
-        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
-        for module_prefix, module in modules:
-            members = get_members_fn(module)
-            for k, v in members:
-                if v is None or v in memo:
-                    continue
-                memo.add(v)
-                name = module.name + '/' + k
-                yield name, v
+    # # TODO RNN enabled after repair
+    # __call__: Callable[..., Any] = _call_impl_tlx
+    #
+    # def _named_members(self, get_members_fn, prefix='', recurse=True):
+    #     r"""Helper method for yielding various names + members of modules."""
+    #     memo = set()
+    #     modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+    #     for module_prefix, module in modules:
+    #         members = get_members_fn(module)
+    #         for k, v in members:
+    #             if v is None or v in memo:
+    #                 continue
+    #             memo.add(v)
+    #             name = module.name + '/' + k
+    #             yield name, v
 
     @property
     def all_weights(self):
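
On this backend ParameterTuple is torch.nn.ParameterList under another name, so members are mutable and are registered with the owning module automatically. A minimal sketch with stock PyTorch:

    # Illustrative only: the alias means ParameterTuple here is mutable
    # and auto-registers its members as sub-parameters.
    import torch
    from torch.nn import ParameterList as ParameterTuple

    class Demo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.params = ParameterTuple(
                [torch.nn.Parameter(torch.zeros(2, 2)) for _ in range(3)]
            )

    m = Demo()
    print(len(list(m.parameters())))   # 3: all members are tracked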

tensorlayerx/nn/layers/recurrent.py

Lines changed: 5 additions & 2 deletions
@@ -4,7 +4,7 @@
 import numpy as np
 import tensorlayerx as tlx
 from tensorlayerx import logging
-from tensorlayerx.nn.core import Module
+from tensorlayerx.nn.core import Module, ParameterTuple
 
 __all__ = [
     'RNN',

@@ -488,7 +488,10 @@ def build(self, inputs_shape):
                         var_name='bias_hh_l{}{}'.format(layer, suffix), shape=(gate_size, ), init=_init
                     )
                 )
-
+        # self.weight_ih = ParameterTuple(self.w_ih)
+        # self.weight_hh = ParameterTuple(self.w_hh)
+        # self.bias_ih = ParameterTuple(self.b_ih)
+        # self.bias_hh = ParameterTuple(self.b_hh)
         self.rnn = tlx.ops.rnnbase(
             mode=self.mode, input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
             bias=self.bias, batch_first=self.batch_first, dropout=self.dropout, bidirectional=self.bidirectional,
