@@ -55,7 +55,7 @@ class Module(object):
     def __init__(self, name=None, act=None, *args, **kwargs):
         self._params = OrderedDict()
         self._layers = OrderedDict()
-        self._params_tuple = OrderedDict()
+        self._params_list = OrderedDict()
         self._params_status = OrderedDict()
         self._parameter_layout_dict = {}
         self._create_time = int(time.time() * 1e9)
@@ -146,19 +146,16 @@ def __setattr__(self, name, value):
                raise TypeError("Expected type is Module, but got Parameter.")
            self.insert_param_to_layer(name, value)

+        elif isinstance(value, ParameterList):
+            self.set_attr_for_parameter_tuple(name, value)
+
        elif isinstance(value, Module):
            if layers is None:
                raise AttributeError("Can not assign layers before Module.__init__() call.")
            if name in self.__dict__:
                del self.__dict__[name]
            if params and name in params:
                raise TypeError("Expected type is Parameter, but got Module.")
-            # TODO Automatic shape inference when the user does not enter inchannels.
-            # if value._built is False:
-            #     raise AttributeError(
-            #         "The registered layer `{}` should be built in advance. "
-            #         "Do you forget to pass the keyword argument 'in_channels'? ".format(value.name)
-            #     )
            layers[name] = value
        else:
            object.__setattr__(self, name, value)
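To make the new `ParameterList` branch of `__setattr__` concrete, here is a minimal usage sketch. It assumes the `Module`, `Parameter`, and `ParameterList` classes from this file with the TensorFlow backend; `DenseStack`, the shapes, and the `w_%d` names are invented for illustration, and `ParameterList` is assumed to accept an iterable of parameters (not verified against the full file):

```python
import tensorflow as tf

class DenseStack(Module):  # hypothetical subclass, for illustration only

    def __init__(self):
        super().__init__()
        # This assignment now hits the `isinstance(value, ParameterList)`
        # branch of __setattr__, which registers each element through
        # set_attr_for_parameter_tuple() instead of storing a plain attribute.
        self.weights = ParameterList(
            [Parameter(tf.zeros((4, 4)), name='w_%d' % i) for i in range(3)]
        )
```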
@@ -253,6 +250,27 @@ def _set_mode_for_layers(self, is_train):
            if isinstance(layer, Module):
                layer.is_train = is_train

+    def set_attr_for_parameter_tuple(self, name, value):
+        """Set attr for parameter in ParameterTuple."""
+        params = self.__dict__.get('_params')
+        params_list = self.__dict__.get('_params_list')
+        if params is None:
+            raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
+        exist_names = set("")
+
+        for item in value:
+            self.insert_param_to_layer(item.name, item, check_name=False)
+            if item.name in exist_names:
+                raise ValueError("The value {} , its name '{}' already exists.".
+                                 format(value, item.name))
+            exist_names.add(item.name)
+
+        if name in self.__dict__:
+            del self.__dict__[name]
+        if name in params:
+            del params[name]
+        params_list[name] = value
+
    def set_train(self):
        """Set this network in training mode. After calling this method,
        all layers in network are in training mode, in particular, BatchNorm, Dropout, etc.
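The duplicate-name guard inside the loop can be exercised as follows. This is again a sketch built on the hypothetical `DenseStack` above, not code from this patch; with the TF backend, `item.name` is the name reported by the underlying `tf.Variable`:

```python
import tensorflow as tf

# Two parameters registered under the same explicit name: on the second
# pass through the loop, set_attr_for_parameter_tuple() finds the name in
# exist_names and raises the new ValueError.
p1 = Parameter(tf.ones((2,)), name='shared')
p2 = Parameter(tf.ones((2,)), name='shared')

net = DenseStack()
try:
    net.biases = ParameterList([p1, p2])
except ValueError as err:
    print(err)  # the second 'shared' parameter is rejected
```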
@@ -345,10 +363,10 @@ def __getattr__(self, name):
            params_status = self.__dict__['_params_status']
            if name in params_status:
                return params_status[name]
-        if '_params_tuple' in self.__dict__:
-            params_tuple = self.__dict__['_params_tuple']
-            if name in params_tuple:
-                para_list = params_tuple[name]
+        if '_params_list' in self.__dict__:
+            params_list = self.__dict__['_params_list']
+            if name in params_list:
+                para_list = params_list[name]
                return para_list
        raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
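After the rename, `__getattr__` resolves list-registered parameters out of `_params_list`. Continuing the hypothetical `DenseStack` sketch (iteration over `ParameterList` is assumed to work as in the PyTorch-style API it mirrors):

```python
net = DenseStack()
# 'weights' was removed from net.__dict__ when it was assigned, so lookup
# falls through to __getattr__, which misses _params, _layers and
# _params_status and finally returns the ParameterList from _params_list.
for p in net.weights:
    print(p.name, p.shape)
```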
@@ -988,7 +1006,7 @@ def update(self, modules):
                self[m[0]] = m[1]


-def Parameter(data=None, requires_grad=True):
+class Parameter(Module):
    """This function creates a parameter. The parameter is a learnable variable, which can have gradient, and can be optimized.

    Parameters
@@ -1009,7 +1027,29 @@ def Parameter(data=None, requires_grad=True):


    """

-    return tf.Variable(initial_value=data, trainable=requires_grad)
+    def __new__(self, data=None, requires_grad=True, name=None):
+        if name is None:
+            prefix = self.__class__.__name__.lower()
+
+            if _global_layer_name_dict.get(prefix) is not None:
+                _global_layer_name_dict[prefix] += 1
+                name = prefix + '_' + str(_global_layer_name_dict[prefix])
+            else:
+                _global_layer_name_dict[prefix] = 0
+                name = prefix
+            while True:
+                if _global_layer_name_dict.get(name) is None:
+                    break
+                _global_layer_name_dict[prefix] += 1
+                name = prefix + '_' + str(_global_layer_name_dict[prefix])
+        else:
+            if _global_layer_name_dict.get(name) is not None:
+                pass
+            else:
+                _global_layer_name_dict[name] = 0
+
+        self.name = name
+        return tf.Variable(initial_value=data, trainable=requires_grad, name=name)


class ParameterList(Module):
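For orientation, a sketch of the call-site behaviour implied by the new `__new__`. It relies on the module-level `_global_layer_name_dict` already used elsewhere in this file, and the exact auto-generated names depend on that shared counter state:

```python
import tensorflow as tf

# Without an explicit name, __new__ builds one from a lowered class-name
# prefix plus a per-prefix counter kept in _global_layer_name_dict.
w = Parameter(tf.random.normal((3, 3)))

# An explicit name is kept as-is and only registered in the dict.
b = Parameter(tf.zeros((3,)), name='bias')

# Because __new__ returns tf.Variable directly, Parameter(...) still yields
# a plain trainable Variable, exactly as the old factory function did.
assert isinstance(w, tf.Variable) and w.trainable
```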
@@ -1068,10 +1108,10 @@ def __setitem__(self, idx, parameter):
        idx = self._get_abs_string_index(idx)
        self._params[str(idx)] = parameter

-    def __setattr__(self, key, value):
-        if not hasattr(self, key) and not isinstance(value, tf.Variable):
-            warnings.warn("Setting attributes on ParameterList is not supported.")
-        super(ParameterList, self).__setattr__(key, value)
+    # def __setattr__(self, key, value):
+    #     if not hasattr(self, key) and not isinstance(value, tf.Variable):
+    #         warnings.warn("Setting attributes on ParameterList is not supported.")
+    #     super(ParameterList, self).__setattr__(key, value)

    def __len__(self):
        return len(self._params)
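With `__setattr__` commented out, attribute writes on `ParameterList` fall back to the inherited `Module.__setattr__` and the old warning disappears. A small sketch under the same assumptions as above (the `note` attribute is hypothetical):

```python
plist = ParameterList([Parameter(tf.ones((2,)))])
# Previously this emitted "Setting attributes on ParameterList is not
# supported."; it is now handled silently by Module.__setattr__, which
# stores plain values via object.__setattr__.
plist.note = "bias group"
```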
@@ -1251,7 +1291,6 @@ def update(self, parameters):
                    "ParameterDict update sequence element "
                    "#" + str(j) + " should be Iterable; is" + type(p).__name__
                )
-                print(p)
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "