Skip to content

Commit 53a3100

Browse files
committed
Fix for the bug: weights/bias name should be the same for original and stripped model.
Change-Id: Ib81f4598356cd96af224ed29e9391d2f26b6bc58
1 parent c4ad2ce commit 53a3100

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def _strip_clustering_wrapper(layer):
197197
for i in range(len(layer.restore)):
198198
# This is why we used integers as keys
199199
name, weight = layer.restore[i]
200+
weight_name = layer.layer.name + "/" + name
200201
# In both cases we use k.batch_get_value since we need physical copies
201202
# of the arrays to initialize a new tensor
202203
if i in layer.gone_variables:
@@ -209,7 +210,7 @@ def _strip_clustering_wrapper(layer):
209210
new_weight_value = k.batch_get_value([weight])[0]
210211
setattr(layer.layer,
211212
name,
212-
k.variable(new_weight_value, name=name))
213+
k.variable(new_weight_value, name=weight_name))
213214
# When all weights are filled with the values, just return the underlying
214215
# layer since it is now fully autonomous from its wrapper
215216
return layer.layer

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ def testValuesAreClusteredAfterStripping(self,
172172
original_model = tf.keras.Sequential([
173173
layers.Dense(32, input_shape=(10,)),
174174
])
175+
176+
weights_name = original_model.layers[0].weights[0].name
177+
bias_name = original_model.layers[0].weights[1].name
178+
175179
clustered_model = cluster.cluster_weights(
176180
original_model,
177181
number_of_clusters=number_of_clusters,
@@ -186,6 +190,9 @@ def testValuesAreClusteredAfterStripping(self,
186190
# Make sure that the stripped layer is the Dense one
187191
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
188192

193+
# Check that we keep names for weights/bias
194+
self.assertEqual(stripped_model.layers[0].weights[0].name, weights_name)
195+
self.assertEqual(stripped_model.layers[0].weights[1].name, bias_name)
189196

190197
if __name__ == '__main__':
191198
test.main()

0 commit comments

Comments
 (0)