File tree Expand file tree Collapse file tree 2 files changed +9
-1
lines changed
tensorflow_model_optimization/python/core/clustering/keras Expand file tree Collapse file tree 2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -197,6 +197,7 @@ def _strip_clustering_wrapper(layer):
197
197
for i in range (len (layer .restore )):
198
198
# This is why we used integers as keys
199
199
name , weight = layer .restore [i ]
200
+ weight_name = layer .layer .name + "/" + name
200
201
# In both cases we use k.batch_get_value since we need physical copies
201
202
# of the arrays to initialize a new tensor
202
203
if i in layer .gone_variables :
@@ -209,7 +210,7 @@ def _strip_clustering_wrapper(layer):
209
210
new_weight_value = k .batch_get_value ([weight ])[0 ]
210
211
setattr (layer .layer ,
211
212
name ,
212
- k .variable (new_weight_value , name = name ))
213
+ k .variable (new_weight_value , name = weight_name ))
213
214
# When all weights are filled with the values, just return the underlying
214
215
# layer since it is now fully autonomous from its wrapper
215
216
return layer .layer
Original file line number Diff line number Diff line change @@ -172,6 +172,10 @@ def testValuesAreClusteredAfterStripping(self,
172
172
original_model = tf .keras .Sequential ([
173
173
layers .Dense (32 , input_shape = (10 ,)),
174
174
])
175
+
176
+ weights_name = original_model .layers [0 ].weights [0 ].name
177
+ bias_name = original_model .layers [0 ].weights [1 ].name
178
+
175
179
clustered_model = cluster .cluster_weights (
176
180
original_model ,
177
181
number_of_clusters = number_of_clusters ,
@@ -186,6 +190,9 @@ def testValuesAreClusteredAfterStripping(self,
186
190
# Make sure that the stripped layer is the Dense one
187
191
self .assertIsInstance (stripped_model .layers [0 ], layers .Dense )
188
192
193
+ # Check that we keep names for weights/bias
194
+ self .assertEqual (stripped_model .layers [0 ].weights [0 ].name , weights_name )
195
+ self .assertEqual (stripped_model .layers [0 ].weights [1 ].name , bias_name )
189
196
190
197
if __name__ == '__main__' :
191
198
test .main ()
You can’t perform that action at this time.
0 commit comments