Skip to content

Commit a20982e

Browse files
fredrectensorflower-gardener
authored andcommitted
Replace deprecated method add_variable() with add_weight().
PiperOrigin-RevId: 436120733
1 parent 7da537f commit a20982e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,14 @@ def build(self, input_shape):
209209

210210
# For each of the prunable weights, add mask and threshold variables
211211
for weight in self.prunable_weights:
212-
mask = self.add_variable(
212+
mask = self.add_weight(
213213
'mask',
214214
shape=weight.shape,
215215
initializer=tf.keras.initializers.get('ones'),
216216
dtype=weight.dtype,
217217
trainable=False,
218218
aggregation=tf.VariableAggregation.MEAN)
219-
threshold = self.add_variable(
219+
threshold = self.add_weight(
220220
'threshold',
221221
shape=[],
222222
initializer=tf.keras.initializers.get('zeros'),
@@ -230,7 +230,7 @@ def build(self, input_shape):
230230
self.pruning_vars = list(zip(weight_vars, mask_vars, threshold_vars))
231231

232232
# Add a scalar tracking the number of updates to the wrapped layer.
233-
self.pruning_step = self.add_variable(
233+
self.pruning_step = self.add_weight(
234234
'pruning_step',
235235
shape=[],
236236
initializer=tf.keras.initializers.Constant(-1),

0 commit comments

Comments
 (0)