diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index 4845abcd7..2b98c6209 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -233,14 +233,14 @@ def build(self, input_shape): # For each of the prunable weights, add mask and threshold variables for weight in self.prunable_weights: mask = self.add_weight( - 'mask', + weight.name + '_mask', shape=weight.shape, initializer=tf.keras.initializers.get('ones'), dtype=weight.dtype, trainable=False, aggregation=tf.VariableAggregation.MEAN) threshold = self.add_weight( - 'threshold', + weight.name + '_threshold', shape=[], initializer=tf.keras.initializers.get('zeros'), dtype=weight.dtype,