Skip to content

Commit 1e7f263

Browse files
Reflect clustered layer's name in new layer's name
1 parent 8465279 commit 1e7f263

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def __init__(self,
6262
'Please initialize `Cluster` layer with a '
6363
'`Layer` instance. You passed: {input}'.format(input=layer))
6464

65+
if 'name' not in kwargs:
66+
kwargs['name'] = self._make_layer_name(layer)
67+
6568
if isinstance(layer, clusterable_layer.ClusterableLayer):
6669
# A user-defined custom layer
6770
super(ClusterWeights, self).__init__(layer, **kwargs)
@@ -133,6 +136,10 @@ def __init__(self,
133136
and hasattr(layer, '_batch_input_shape'):
134137
self._batch_input_shape = self.layer._batch_input_shape
135138

139+
@staticmethod
140+
def _make_layer_name(layer):
141+
return '{}_{}'.format('cluster', layer.name)
142+
136143
@staticmethod
137144
def _weight_name(name):
138145
"""Extracts the weight name from the full TensorFlow variable name.

0 commit comments

Comments
 (0)