Commit ba5c605

Added TF collection of list parameters (#8)
1 parent b7e7dc2 commit ba5c605

File tree: 4 files changed, +82 -38 lines changed
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-paddlepaddle==2.0.2
+paddlepaddle==2.2.0

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 57 additions & 18 deletions
@@ -55,7 +55,7 @@ class Module(object):
     def __init__(self, name=None, act=None, *args, **kwargs):
         self._params = OrderedDict()
         self._layers = OrderedDict()
-        self._params_tuple = OrderedDict()
+        self._params_list = OrderedDict()
         self._params_status = OrderedDict()
         self._parameter_layout_dict = {}
         self._create_time = int(time.time() * 1e9)
@@ -146,19 +146,16 @@ def __setattr__(self, name, value):
                 raise TypeError("Expected type is Module, but got Parameter.")
             self.insert_param_to_layer(name, value)

+        elif isinstance(value, ParameterList):
+            self.set_attr_for_parameter_tuple(name, value)
+
         elif isinstance(value, Module):
             if layers is None:
                 raise AttributeError("Can not assign layers before Module.__init__() call.")
             if name in self.__dict__:
                 del self.__dict__[name]
             if params and name in params:
                 raise TypeError("Expected type is Parameter, but got Module.")
-            # TODO Automatic shape inference when the user does not enter inchannels.
-            # if value._built is False:
-            #     raise AttributeError(
-            #         "The registered layer `{}` should be built in advance. "
-            #         "Do you forget to pass the keyword argument 'in_channels'? ".format(value.name)
-            #     )
             layers[name] = value
         else:
             object.__setattr__(self, name, value)
@@ -253,6 +250,27 @@ def _set_mode_for_layers(self, is_train):
                 if isinstance(layer, Module):
                     layer.is_train = is_train

+    def set_attr_for_parameter_tuple(self, name, value):
+        """Set attr for parameter in ParameterTuple."""
+        params = self.__dict__.get('_params')
+        params_list = self.__dict__.get('_params_list')
+        if params is None:
+            raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
+        exist_names = set("")
+
+        for item in value:
+            self.insert_param_to_layer(item.name, item, check_name=False)
+            if item.name in exist_names:
+                raise ValueError("The value {} , its name '{}' already exists.".
+                                 format(value, item.name))
+            exist_names.add(item.name)
+
+        if name in self.__dict__:
+            del self.__dict__[name]
+        if name in params:
+            del params[name]
+        params_list[name] = value
+
     def set_train(self):
         """Set this network in training mode. After calling this method,
         all layers in network are in training mode, in particular, BatchNorm, Dropout, etc.
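
A usage note, not part of the diff: with the two hunks above, assigning a ParameterList to a Module attribute now registers every contained variable. A minimal sketch, assuming the TensorFlow backend is active; ScaleLayer is a hypothetical layer used only for illustration:

import tensorflow as tf
from tensorlayerx.nn import Module, Parameter, ParameterList

class ScaleLayer(Module):  # hypothetical, for illustration only
    def __init__(self):
        super(ScaleLayer, self).__init__()
        # Module.__setattr__ sees a ParameterList and calls
        # set_attr_for_parameter_tuple(), which registers each variable in
        # self._params and stores the list itself in self._params_list.
        self.scales = ParameterList([Parameter(data=tf.ones((3,))) for _ in range(2)])

layer = ScaleLayer()
print(len(layer.scales))  # __getattr__ falls back to _params_list: prints 2

Reading the attribute back goes through the __getattr__ change in the next hunk.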
@@ -345,10 +363,10 @@ def __getattr__(self, name):
            params_status = self.__dict__['_params_status']
            if name in params_status:
                return params_status[name]
-        if '_params_tuple' in self.__dict__:
-            params_tuple = self.__dict__['_params_tuple']
-            if name in params_tuple:
-                para_list = params_tuple[name]
+        if '_params_list' in self.__dict__:
+            params_list = self.__dict__['_params_list']
+            if name in params_list:
+                para_list = params_list[name]
                return para_list
        raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))

@@ -988,7 +1006,7 @@ def update(self, modules):
            self[m[0]] = m[1]


-def Parameter(data=None, requires_grad=True):
+class Parameter(Module):
    """This function creates a parameter. The parameter is a learnable variable, which can have gradient, and can be optimized.

    Parameters
@@ -1009,7 +1027,29 @@ def Parameter(data=None, requires_grad=True):

    """

-    return tf.Variable(initial_value=data, trainable=requires_grad)
+    def __new__(self, data=None, requires_grad=True, name=None):
+        if name is None:
+            prefix = self.__class__.__name__.lower()
+
+            if _global_layer_name_dict.get(prefix) is not None:
+                _global_layer_name_dict[prefix] += 1
+                name = prefix + '_' + str(_global_layer_name_dict[prefix])
+            else:
+                _global_layer_name_dict[prefix] = 0
+                name = prefix
+            while True:
+                if _global_layer_name_dict.get(name) is None:
+                    break
+                _global_layer_name_dict[prefix] += 1
+                name = prefix + '_' + str(_global_layer_name_dict[prefix])
+        else:
+            if _global_layer_name_dict.get(name) is not None:
+                pass
+            else:
+                _global_layer_name_dict[name] = 0
+
+        self.name = name
+        return tf.Variable(initial_value=data, trainable=requires_grad, name=name)


class ParameterList(Module):
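
A usage note, not part of the diff: Parameter is now a class whose __new__ still returns a plain tf.Variable, with a name drawn from the _global_layer_name_dict counter when none is given. A minimal sketch, assuming the TensorFlow backend:

import tensorflow as tf
from tensorlayerx.nn import Parameter

# No explicit name: __new__ generates one from the global counter.
w = Parameter(data=tf.zeros((4, 4)))
# Explicit name: __new__ records it in _global_layer_name_dict instead.
b = Parameter(data=tf.zeros((4,)), name='bias_0', requires_grad=False)

assert isinstance(w, tf.Variable)  # still an ordinary tf.Variable
assert b.trainable is False        # requires_grad maps to trainable

Because __new__ returns a tf.Variable rather than a Parameter instance, Python never calls Parameter.__init__, so existing optimizer and tf.function code keeps working unchanged.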
@@ -1068,10 +1108,10 @@ def __setitem__(self, idx, parameter):
        idx = self._get_abs_string_index(idx)
        self._params[str(idx)] = parameter

-    def __setattr__(self, key, value):
-        if not hasattr(self, key) and not isinstance(value, tf.Variable):
-            warnings.warn("Setting attributes on ParameterList is not supported.")
-        super(ParameterList, self).__setattr__(key, value)
+    # def __setattr__(self, key, value):
+    #     if not hasattr(self, key) and not isinstance(value, tf.Variable):
+    #         warnings.warn("Setting attributes on ParameterList is not supported.")
+    #     super(ParameterList, self).__setattr__(key, value)

    def __len__(self):
        return len(self._params)
@@ -1251,7 +1291,6 @@ def update(self, parameters):
                        "ParameterDict update sequence element "
                        "#" + str(j) + " should be Iterable; is" + type(p).__name__
                    )
-                print(p)
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "

tensorlayerx/nn/core/core_torch.py

Lines changed: 16 additions & 15 deletions
@@ -122,21 +122,22 @@ def _call_impl_tlx(self, *input, **kwargs):

        result = self._call_impl(*input, **kwargs)
        return result
-    # # TODO RNN enabled after repair
-    # __call__: Callable[..., Any] = _call_impl_tlx
-    #
-    # def _named_members(self, get_members_fn, prefix='', recurse=True):
-    #     r"""Helper method for yielding various names + members of modules."""
-    #     memo = set()
-    #     modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
-    #     for module_prefix, module in modules:
-    #         members = get_members_fn(module)
-    #         for k, v in members:
-    #             if v is None or v in memo:
-    #                 continue
-    #             memo.add(v)
-    #             name = module.name + '/' + k
-    #             yield name, v
+    # TODO RNN enabled after repair
+    __call__: Callable[..., Any] = _call_impl_tlx
+
+    def _named_members(self, get_members_fn, prefix='', recurse=True):
+        r"""Helper method for yielding various names + members of modules."""
+        memo = set()
+        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+        for module_prefix, module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                name = module.name + '/' + k
+                yield name, v

    @property
    def all_weights(self):
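
An illustrative note, not part of the diff: the re-enabled _named_members deduplicates members shared across modules with a memo set and keys each one as module.name + '/' + attribute. A standalone sketch of the same iteration pattern, with hypothetical inputs:

# `modules` is any iterable of (module_name, members) pairs; identity-based
# dedup stands in for the memo-set check in _named_members above.
def iter_named_members(modules):
    memo = set()
    for module_name, members in modules:
        for attr, value in members:
            if value is None or id(value) in memo:
                continue
            memo.add(id(value))
            yield module_name + '/' + attr, value

w = object()  # a member shared by two modules
mods = [('enc', [('w', w)]), ('dec', [('w', w), ('b', object())])]
print([name for name, _ in iter_named_members(mods)])  # ['enc/w', 'dec/b']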

tensorlayerx/nn/layers/recurrent.py

Lines changed: 8 additions & 4 deletions
@@ -448,10 +448,10 @@ def __repr__(self):

    def build(self, inputs_shape):
        bidirect = 2 if self.bidirectional else 1
-        self.weight_ih = ParameterList()
-        self.weight_hh = ParameterList()
-        self.bias_ih = ParameterList()
-        self.bias_hh = ParameterList()
+        self.weight_ih = []
+        self.weight_hh = []
+        self.bias_ih = []
+        self.bias_hh = []
        stdv = 1.0 / np.sqrt(self.hidden_size)
        _init = tlx.nn.initializers.RandomUniform(minval=-stdv, maxval=stdv)
        if self.mode == 'LSTM':
@@ -488,6 +488,10 @@ def build(self, inputs_shape):
                        var_name='bias_hh_l{}{}'.format(layer, suffix), shape=(gate_size, ), init=_init
                    )
                )
+        self.weight_ih = ParameterList(self.weight_ih)
+        self.weight_hh = ParameterList(self.weight_hh)
+        self.bias_ih = ParameterList(self.bias_ih)
+        self.bias_hh = ParameterList(self.bias_hh)
        self.rnn = tlx.ops.rnnbase(
            mode=self.mode, input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
            bias=self.bias, batch_first=self.batch_first, dropout=self.dropout, bidirectional=self.bidirectional,
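
A usage note, not part of the diff: build() now accumulates plain Python lists during the per-layer loop and wraps each list in a ParameterList only after all variables exist, so registration through Module.__setattr__ happens in a single pass rather than by mutating an already-registered ParameterList. A minimal sketch of the pattern, with hypothetical shapes, assuming the TensorFlow backend:

import tensorflow as tf
from tensorlayerx.nn import ParameterList

weight_ih = []  # plain list while the variables are being created
for layer in range(2):  # e.g. num_layers = 2
    weight_ih.append(tf.Variable(tf.zeros((16, 8)), name='weight_ih_l{}'.format(layer)))

weight_ih = ParameterList(weight_ih)  # single wrap-up, mirroring build()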
