Skip to content

Commit 7dc76a4

Browse files
committed
Addressed reviewer's comments.
Change-Id: Ie5395e5cafc3544f909294650ff1b8de4c6153f7
1 parent 53a3100 commit 7dc76a4

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,7 @@ def _strip_clustering_wrapper(layer):
196196
layer.layer._non_trainable_weights = []
197197
for i in range(len(layer.restore)):
198198
# This is why we used integers as keys
199-
name, weight = layer.restore[i]
200-
weight_name = layer.layer.name + "/" + name
199+
name, weight_name, weight = layer.restore[i]
201200
# In both cases we use k.batch_get_value since we need physical copies
202201
# of the arrays to initialize a new tensor
203202
if i in layer.gone_variables:

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,13 @@ def fn():
244244

245245
for ct, weight in enumerate(self.layer.weights):
246246
name = self._weight_name(weight.name)
247+
full_name = self.layer.name + "/" + name
247248
if ct in self.gone_variables:
248249
# Again, not sure if this is needed
249250
weight_name = clusterable_weights_to_variables[name]
250-
self.restore.append((name, get_updater(weight_name)))
251+
self.restore.append((name, full_name, get_updater(weight_name)))
251252
else:
252-
self.restore.append((name, weight))
253+
self.restore.append((name, full_name, weight))
253254

254255
def call(self, inputs):
255256
# Go through all tensors and replace them with their clustered copies.

0 commit comments

Comments
 (0)