Skip to content

Commit 3ebe133

Browse files
liyunlu0618alanchiao
authored andcommitted
Save mask and threshold to model checkpoint.
PiperOrigin-RevId: 246361345
1 parent 627554f commit 3ebe133

File tree

3 files changed

+60
-3
lines changed

3 files changed

+60
-3
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ py_library(
9797
# python:control_flow_ops tensorflow dep2,
9898
# python:dtypes tensorflow dep2,
9999
# python:framework_ops tensorflow dep2,
100+
# python:variables tensorflow dep2,
100101
# python/keras:backend tensorflow dep2,
101102
# python/keras:engine tensorflow dep2,
102103
# python/keras:generic_utils tensorflow dep2,

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,49 @@ def testPruneStopAndRestartOnModel(self, save_restore_fn):
259259

260260
self._check_strip_pruning_matches_original(model, 0.6)
261261

262+
@parameterized.parameters(test_utils.save_restore_fns())
263+
def testPruneWithPolynomialDecayPreservesSparsity(self, save_restore_fn):
264+
params = {
265+
'pruning_schedule': pruning_schedule.PolynomialDecay(
266+
0.2, 0.6, 0, 1, 3, 1),
267+
'block_size': (1, 1),
268+
'block_pooling_type': 'AVG'
269+
}
270+
model = prune.prune_low_magnitude(
271+
test_utils.build_simple_dense_model(), **params)
272+
model.compile(
273+
loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
274+
# Model hasn't been trained yet. Sparsity 0.0
275+
test_utils.assert_model_sparsity(self, 0.0, model)
276+
277+
model.fit(
278+
np.random.rand(20, 10),
279+
keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
280+
batch_size=20,
281+
callbacks=[pruning_callbacks.UpdatePruningStep()])
282+
# Training has run only 1 step. Sparsity 0.2 (initial_sparsity)
283+
test_utils.assert_model_sparsity(self, 0.2, model)
284+
285+
model.fit(
286+
np.random.rand(20, 10),
287+
keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
288+
batch_size=20,
289+
callbacks=[pruning_callbacks.UpdatePruningStep()])
290+
# Training has run 2 steps. Sparsity 0.6 (final_sparsity)
291+
test_utils.assert_model_sparsity(self, 0.6, model)
292+
293+
model = save_restore_fn(model)
294+
model.fit(
295+
np.random.rand(20, 10),
296+
keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
297+
batch_size=20,
298+
epochs=2,
299+
callbacks=[pruning_callbacks.UpdatePruningStep()])
300+
# Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
301+
test_utils.assert_model_sparsity(self, 0.6, model)
302+
303+
self._check_strip_pruning_matches_original(model, 0.6)
304+
262305
def testPrunesPreviouslyUnprunedModel(self):
263306
model = test_utils.build_simple_dense_model()
264307
model.compile(

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensorflow.python.keras.utils import tf_utils
3232
from tensorflow.python.ops import check_ops
3333
from tensorflow.python.ops import control_flow_ops
34+
from tensorflow.python.ops import variables as tf_variables
3435
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
3536
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
3637
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
@@ -175,8 +176,20 @@ def build(self, input_shape):
175176

176177
# For each of the prunable weights, add mask and threshold variables
177178
for weight in self.prunable_weights:
178-
mask = K.variable(np.ones(weight.shape), dtype=weight.dtype)
179-
threshold = K.variable(0, dtype=weight.dtype)
179+
mask = self.add_variable(
180+
'mask',
181+
shape=weight.shape,
182+
initializer=initializers.get('ones'),
183+
dtype=weight.dtype,
184+
trainable=False,
185+
aggregation=tf_variables.VariableAggregation.MEAN)
186+
threshold = self.add_variable(
187+
'threshold',
188+
shape=[],
189+
initializer=initializers.get('zeros'),
190+
dtype=weight.dtype,
191+
trainable=False,
192+
aggregation=tf_variables.VariableAggregation.MEAN)
180193

181194
weight_vars.append(weight)
182195
mask_vars.append(mask)
@@ -273,7 +286,7 @@ def trainable_weights(self):
273286

274287
@property
275288
def non_trainable_weights(self):
276-
return self.layer.non_trainable_weights
289+
return self.layer.non_trainable_weights + self._non_trainable_weights
277290

278291
@property
279292
def updates(self):

0 commit comments

Comments
 (0)