File tree Expand file tree Collapse file tree 2 files changed +9
-1
lines changed
tensorflow_model_optimization/python/core/clustering/keras Expand file tree Collapse file tree 2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -197,6 +197,7 @@ def _strip_clustering_wrapper(layer):
197197 for i in range (len (layer .restore )):
198198 # This is why we used integers as keys
199199 name , weight = layer .restore [i ]
200+ weight_name = layer .layer .name + "/" + name
200201 # In both cases we use k.batch_get_value since we need physical copies
201202 # of the arrays to initialize a new tensor
202203 if i in layer .gone_variables :
@@ -209,7 +210,7 @@ def _strip_clustering_wrapper(layer):
209210 new_weight_value = k .batch_get_value ([weight ])[0 ]
210211 setattr (layer .layer ,
211212 name ,
212- k .variable (new_weight_value , name = name ))
213+ k .variable (new_weight_value , name = weight_name ))
213214 # When all weights are filled with the values, just return the underlying
214215 # layer since it is now fully autonomous from its wrapper
215216 return layer .layer
Original file line number Diff line number Diff line change @@ -172,6 +172,10 @@ def testValuesAreClusteredAfterStripping(self,
172172 original_model = tf .keras .Sequential ([
173173 layers .Dense (32 , input_shape = (10 ,)),
174174 ])
175+
176+ weights_name = original_model .layers [0 ].weights [0 ].name
177+ bias_name = original_model .layers [0 ].weights [1 ].name
178+
175179 clustered_model = cluster .cluster_weights (
176180 original_model ,
177181 number_of_clusters = number_of_clusters ,
@@ -186,6 +190,9 @@ def testValuesAreClusteredAfterStripping(self,
186190 # Make sure that the stripped layer is the Dense one
187191 self .assertIsInstance (stripped_model .layers [0 ], layers .Dense )
188192
193+ # Check that we keep names for weights/bias
194+ self .assertEqual (stripped_model .layers [0 ].weights [0 ].name , weights_name )
195+ self .assertEqual (stripped_model .layers [0 ].weights [1 ].name , bias_name )
189196
190197if __name__ == '__main__' :
191198 test .main ()
You can’t perform that action at this time.
0 commit comments