Skip to content

Commit 3535acd

Browse files
Merge pull request #517 from wwwind:bug_inconsistency_weights_name
PiperOrigin-RevId: 328666954
2 parents 82b698d + 307ebd8 commit 3535acd

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _strip_clustering_wrapper(layer):
196196
layer.layer._non_trainable_weights = []
197197
for i in range(len(layer.restore)):
198198
# This is why we used integers as keys
199-
name, weight = layer.restore[i]
199+
name, weight_name, weight = layer.restore[i]
200200
# In both cases we use k.batch_get_value since we need physical copies
201201
# of the arrays to initialize a new tensor
202202
if i in layer.gone_variables:
@@ -209,7 +209,7 @@ def _strip_clustering_wrapper(layer):
209209
new_weight_value = k.batch_get_value([weight])[0]
210210
setattr(layer.layer,
211211
name,
212-
k.variable(new_weight_value, name=name))
212+
k.variable(new_weight_value, name=weight_name))
213213
# When all weights are filled with the values, just return the underlying
214214
# layer since it is now fully autonomous from its wrapper
215215
return layer.layer

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,13 @@ def fn():
244244

245245
for ct, weight in enumerate(self.layer.weights):
246246
name = self._weight_name(weight.name)
247+
full_name = self.layer.name + "/" + name
247248
if ct in self.gone_variables:
248249
# Again, not sure if this is needed
249250
weight_name = clusterable_weights_to_variables[name]
250-
self.restore.append((name, get_updater(weight_name)))
251+
self.restore.append((name, full_name, get_updater(weight_name)))
251252
else:
252-
self.restore.append((name, weight))
253+
self.restore.append((name, full_name, weight))
253254

254255
def call(self, inputs):
255256
# Go through all tensors and replace them with their clustered copies.

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,12 @@ def testValuesAreClusteredAfterStripping(self,
170170
or equal to number_of_clusters.
171171
"""
172172
original_model = tf.keras.Sequential([
173-
layers.Dense(32, input_shape=(10,)),
173+
layers.Dense(32, input_shape=(10,), name='abc'),
174174
])
175+
176+
weights_name = original_model.layers[0].weights[0].name
177+
bias_name = original_model.layers[0].weights[1].name
178+
175179
clustered_model = cluster.cluster_weights(
176180
original_model,
177181
number_of_clusters=number_of_clusters,
@@ -186,6 +190,9 @@ def testValuesAreClusteredAfterStripping(self,
186190
# Make sure that the stripped layer is the Dense one
187191
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
188192

193+
# Check that we keep names for weights/bias
194+
self.assertEqual(stripped_model.layers[0].weights[0].name, weights_name)
195+
self.assertEqual(stripped_model.layers[0].weights[1].name, bias_name)
189196

190197
if __name__ == '__main__':
191198
test.main()

0 commit comments

Comments
 (0)