@@ -56,6 +56,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):
5656 self ._params = OrderedDict ()
5757 self ._layers = OrderedDict ()
5858 self ._params_list = OrderedDict ()
59+ self ._params_dict = OrderedDict ()
5960 self ._params_status = OrderedDict ()
6061 self ._parameter_layout_dict = {}
6162 self ._create_time = int (time .time () * 1e9 )
@@ -105,6 +106,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):
105106
106107 # weights check state
107108 self ._check = False
109+ self .trainable = True
108110
109111 def extend_repr (self ):
110112 """
@@ -149,6 +151,9 @@ def __setattr__(self, name, value):
149151 elif isinstance (value , ParameterList ):
150152 self .set_attr_for_parameter_tuple (name , value )
151153
154+ elif isinstance (value , ParameterDict ):
155+ self .set_attr_for_parameter_dict (name , value )
156+
152157 elif isinstance (value , Module ):
153158 if layers is None :
154159 raise AttributeError ("Can not assign layers before Module.__init__() call." )
@@ -250,6 +255,26 @@ def _set_mode_for_layers(self, is_train):
250255 if isinstance (layer , Module ):
251256 layer .is_train = is_train
252257
258+ def set_attr_for_parameter_dict (self , name , value ):
259+ """Set attr for parameter in ParameterDict."""
260+ params = self .__dict__ .get ('_params' )
261+ params_dict = self .__dict__ .get ('_params_dict' )
262+ if params is None :
263+ raise AttributeError ("For 'Module', can not assign params before Module.__init__() is called." )
264+ exist_names = set ("" )
265+ for item in value :
266+ self .insert_param_to_layer (item , value [item ], check_name = False )
267+ if item in exist_names :
268+ raise ValueError ("The value {} , its name '{}' already exists." .
269+ format (value [item ], item ))
270+ exist_names .add (item )
271+
272+ if name in self .__dict__ :
273+ del self .__dict__ [name ]
274+ if name in params :
275+ del params [name ]
276+ params_dict [name ] = value
277+
253278 def set_attr_for_parameter_tuple (self , name , value ):
254279 """Set attr for parameter in ParameterTuple."""
255280 params = self .__dict__ .get ('_params' )
@@ -368,6 +393,10 @@ def __getattr__(self, name):
368393 if name in params_list :
369394 para_list = params_list [name ]
370395 return para_list
396+ if '_params_dict' in self .__dict__ :
397+ params_dict = self .__dict__ ['_params_dict' ]
398+ if name in params_dict :
399+ return params_dict [name ]
371400 raise AttributeError ("'{}' object has no attribute '{}'." .format (type (self ).__name__ , name ))
372401
373402 def __delattr__ (self , name ):
@@ -1027,9 +1056,10 @@ class Parameter(Module):
10271056
10281057 """
10291058
1030- def __new__ (self , data = None , requires_grad = True , name = None ):
1059+ def __new__ (self , data = None , name = None ):
1060+ instance = super ().__new__ (self )
10311061 if name is None :
1032- prefix = self . __class__ . __name__ . lower ()
1062+ prefix = 'parameter'
10331063
10341064 if _global_layer_name_dict .get (prefix ) is not None :
10351065 _global_layer_name_dict [prefix ] += 1
@@ -1047,9 +1077,13 @@ def __new__(self, data=None, requires_grad=True, name=None):
10471077 pass
10481078 else :
10491079 _global_layer_name_dict [name ] = 0
1080+ if data is None :
1081+ return instance
1082+ else :
1083+ return instance (data , name )
10501084
1051- self . name = name
1052- return tf .Variable (initial_value = data , trainable = requires_grad , name = name )
1085+ def __call__ ( self , data = None , name = None , ** kwargs ):
1086+ return tf .Variable (initial_value = data , name = name )
10531087
10541088
10551089class ParameterList (Module ):
@@ -1219,10 +1253,10 @@ def __setitem__(self, key, parameter):
12191253 def __delitem__ (self , key ):
12201254 del self ._params [key ]
12211255
1222- def __setattr__ (self , key , value ):
1223- if not hasattr (self , key ) and not isinstance (value , tf .Variable ):
1224- warnings .warn ("Setting attributes on ParameterDict is not supported." )
1225- super (ParameterDict , self ).__setattr__ (key , value )
1256+ # def __setattr__(self, key, value):
1257+ # if not hasattr(self, key) and not isinstance(value, tf.Variable):
1258+ # warnings.warn("Setting attributes on ParameterDict is not supported.")
1259+ # super(ParameterDict, self).__setattr__(key, value)
12261260
12271261 def __len__ (self ) -> int :
12281262 return len (self ._params )
0 commit comments