@@ -93,9 +93,9 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         self._nodes_fixed = False
 
         # Layer weight state
-        self._all_weights = []
-        self._trainable_weights = []
-        self._nontrainable_weights = []
+        self._all_weights = None
+        self._trainable_weights = None
+        self._nontrainable_weights = None
 
         # layer forward state
         self._forward_state = False
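Initializing the three caches to None instead of [] is what makes the lazy extraction in the hunks below work: None marks "not collected yet", whereas the old list default let every property access append into the same shared list. A minimal repro of that pre-patch failure mode (class and parameter names here are illustrative, not the library's):

```python
class BuggyLayer:
    """Sketch of the pre-patch behaviour: the cache starts as a list."""

    def __init__(self):
        self._trainable_weights = []  # old default

    @property
    def trainable_weights(self):
        # Pre-patch code appended on every access, never checking the cache.
        for p in ("w", "b"):  # stand-in for the sub-layer parameter walk
            self._trainable_weights.append(p)
        return self._trainable_weights

layer = BuggyLayer()
print(len(layer.trainable_weights))  # 2
print(len(layer.trainable_weights))  # 4 -- duplicates accumulate per access
```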
@@ -333,15 +333,19 @@ def trainable_weights(self):
 
         """
 
-        self.get_weights()
-        layers = self.layers_and_names(name_prefix='')
-        for layer_name, layer in layers:
-            params = layer._params.items()
-            params_status = layer._params_status.items()
-            params_zip = zip(params, params_status)
-            for params, params_status in params_zip:
-                if params_status[1] == True:
-                    self._trainable_weights.append(params[1])
+        if self._trainable_weights is not None and len(self._trainable_weights) > 0:
+            # self._trainable_weights already extracted, so do nothing
+            pass
+        else:
+            self._trainable_weights = []
+            layers = self.layers_and_names(name_prefix='')
+            for layer_name, layer in layers:
+                params = layer._params.items()
+                params_status = layer._params_status.items()
+                params_zip = zip(params, params_status)
+                for params, params_status in params_zip:
+                    if params_status[1] == True:
+                        self._trainable_weights.append(params[1])
         return self._trainable_weights
 
     @property
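Taken in isolation, the patched property is a build-once cache. A condensed sketch of its control flow, with the layer walk stubbed out (the `(tensor, trainable)` pairs stand in for the diff's `_params` / `_params_status` entries):

```python
class Layer:
    """Sketch of the patched lazy property; the parameter walk is stubbed."""

    def __init__(self):
        self._trainable_weights = None  # sentinel: not collected yet

    def _iter_params(self):
        # Stand-in for layers_and_names() plus _params/_params_status.
        yield "w", True
        yield "b", True
        yield "moving_mean", False

    @property
    def trainable_weights(self):
        # Rebuild only when the cache is None or empty, as in the diff.
        if self._trainable_weights is not None and len(self._trainable_weights) > 0:
            return self._trainable_weights
        self._trainable_weights = [t for t, ok in self._iter_params() if ok]
        return self._trainable_weights

layer = Layer()
assert layer.trainable_weights is layer.trainable_weights  # same cached list
```

The guard is behaviourally the same as the more idiomatic `if not self._trainable_weights:` rebuild path. Either way, one edge case remains: a layer with zero trainable parameters re-walks its sub-layers on every access, since an empty cache is indistinguishable from an unbuilt one.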
@@ -352,14 +356,19 @@ def nontrainable_weights(self):
 
         """
 
-        layers = self.layers_and_names(name_prefix='')
-        for layer_name, layer in layers:
-            params = layer._params.items()
-            params_status = layer._params_status.items()
-            params_zip = zip(params, params_status)
-            for params, params_status in params_zip:
-                if params_status[1] == False:
-                    self._nontrainable_weights.append(params[1])
+        if self._nontrainable_weights is not None and len(self._nontrainable_weights) > 0:
+            # self._nontrainable_weights already extracted, so do nothing
+            pass
+        else:
+            self._nontrainable_weights = []
+            layers = self.layers_and_names(name_prefix='')
+            for layer_name, layer in layers:
+                params = layer._params.items()
+                params_status = layer._params_status.items()
+                params_zip = zip(params, params_status)
+                for params, params_status in params_zip:
+                    if params_status[1] == False:
+                        self._nontrainable_weights.append(params[1])
         return self._nontrainable_weights
 
     @property
@@ -370,11 +379,16 @@ def all_weights(self):
 
         """
 
-        layers = self.layers_and_names(name_prefix='')
-        for layer_name, layer in layers:
-            params = layer._params.items()
-            for par, val in params:
-                self._all_weights.append(val)
+        if self._all_weights is not None and len(self._all_weights) > 0:
+            # self._all_weights already extracted, so do nothing
+            pass
+        else:
+            self._all_weights = []
+            layers = self.layers_and_names(name_prefix='')
+            for layer_name, layer in layers:
+                params = layer._params.items()
+                for par, val in params:
+                    self._all_weights.append(val)
         return self._all_weights
 
     def get_weights(self, expand=True):
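Since `trainable_weights`, `nontrainable_weights`, and `all_weights` now differ only in their filter, the sub-layer walk could be factored into one helper. A hedged sketch (the `_collect_weights` name is hypothetical, not part of this patch), reusing only the attributes the diff already relies on:

```python
def _collect_weights(self, status=None):
    """Walk sub-layers and gather parameter tensors.

    status=None gathers everything (all_weights); True/False filters on
    trainability, mirroring the three properties above. Hypothetical helper,
    not part of the committed patch.
    """
    weights = []
    for layer_name, layer in self.layers_and_names(name_prefix=''):
        params = layer._params.items()
        params_status = layer._params_status.items()
        # The two dicts are parallel, so zipping their items pairs each
        # tensor with its trainability flag, as the original loops do.
        for (name, tensor), (_, trainable) in zip(params, params_status):
            if status is None or trainable == status:
                weights.append(tensor)
    return weights
```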