Commit 782fd70

Update the way of incrementing pruning_step.
PiperOrigin-RevId: 352878617
1 parent: e4a5200

5 files changed: +41 −20 lines

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ py_library(
         # numpy dep1,
         # tensorflow dep1,
         # python/keras/utils:generic_utils tensorflow dep2,
+        "//tensorflow_model_optimization/python/core/keras:compat",
         "//tensorflow_model_optimization/python/core/keras:utils",
     ],
 )

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py

Lines changed: 4 additions & 2 deletions
@@ -473,7 +473,7 @@ def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
     if save_restore_fn.__name__ == '_save_restore_tf_model':
       return
 
-    begin_step, end_step = 0, 4
+    begin_step, end_step = 1, 4
     params = self.params
     params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
         0.2, 0.6, begin_step, end_step, 3, 1)
@@ -542,7 +542,9 @@ class PruneIntegrationCustomTrainingLoopTest(tf.test.TestCase,
 
   def testPrunesModel_CustomTrainingLoop_ReachesTargetSparsity(self):
     pruned_model = prune.prune_low_magnitude(
-        keras_test_utils.build_simple_dense_model())
+        keras_test_utils.build_simple_dense_model(),
+        pruning_schedule=pruning_schedule.ConstantSparsity(
+            target_sparsity=0.5, begin_step=0, frequency=1))
 
     batch_size = 20
     x_train = np.random.rand(20, 10)
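The custom-training-loop test now pins an explicit schedule instead of relying on the default pruning frequency. As a rough sketch of the same call through the public tfmot API (the toy two-layer model below is an assumption, standing in for keras_test_utils.build_simple_dense_model()):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Toy stand-in for keras_test_utils.build_simple_dense_model().
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1),
])

# frequency=1 prunes on every training step once begin_step is reached,
# so the test hits the 0.5 target regardless of how many steps it runs.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.5, begin_step=0, frequency=1))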

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py

Lines changed: 7 additions & 9 deletions
@@ -64,15 +64,13 @@ def __init__(self):
   def on_train_begin(self, logs=None):
     # Collect all the prunable layers in the model.
     self.prunable_layers = _collect_prunable_layers(self.model)
-    self.step = K.get_value(self.model.optimizer.iterations)
-
-  def on_train_batch_begin(self, batch, logs=None):
-    tuples = []
-    for layer in self.prunable_layers:
-      tuples.append((layer.pruning_step, self.step))
-
-    K.batch_set_value(tuples)
-    self.step = self.step + 1
+    # If the model is newly created/initialized, set the 'pruning_step' to 0.
+    # If the model is saved and then restored, do nothing.
+    if self.prunable_layers[0].pruning_step == -1:
+      tuples = []
+      for layer in self.prunable_layers:
+        tuples.append((layer.pruning_step, 0))
+      K.batch_set_value(tuples)
 
   def on_epoch_end(self, batch, logs=None):
     # At the end of every epoch, remask the weights. This ensures that when
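After this change the callback no longer copies optimizer.iterations into every layer on each batch; it only seeds pruning_step to 0 when the model is fresh (sentinel value -1) and leaves a restored model's step untouched, so schedules resume where they stopped. A minimal sketch of the usual fit() wiring, continuing the pruned_model above with hypothetical training data:

import numpy as np
import tensorflow_model_optimization as tfmot

x_train = np.random.rand(20, 10)  # hypothetical data
y_train = np.random.rand(20, 1)

pruned_model.compile(optimizer='adam', loss='mse')
# UpdatePruningStep is still required: it seeds pruning_step on train begin
# and remasks weights at epoch end; the per-step increment now lives in the
# wrapper's call().
pruned_model.fit(x_train, y_train, batch_size=20, epochs=2,
                 callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])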

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py

Lines changed: 16 additions & 4 deletions
@@ -70,9 +70,9 @@ def testUpdatePruningStepsAndLogsSummaries(self):
     ])
 
     self.assertEqual(
-        2, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
+        3, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
     self.assertEqual(
-        2, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
+        3, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
 
     self._assertLogsExist(log_dir)
 
@@ -111,11 +111,23 @@ def testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop(self):
     step_callback.on_epoch_end(batch=unused_arg)
 
     self.assertEqual(
-        2, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
+        3, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
     self.assertEqual(
-        2, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
+        3, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
     self._assertLogsExist(log_dir)
 
+  @keras_parameterized.run_all_keras_modes
+  def testUpdatePruningStepsAndLogsSummaries_RunInference(self):
+    pruned_model, _, _, x_train, _ = self._pruned_model_setup(
+        custom_training_loop=True)
+    model_output = pruned_model(x_train)
+    del model_output
+
+    self.assertEqual(
+        -1, tf.keras.backend.get_value(pruned_model.layers[0].pruning_step))
+    self.assertEqual(
+        -1, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step))
+
   @keras_parameterized.run_all_keras_modes
   def testPruneTrainingRaisesError_PruningStepCallbackMissing(self):
     pruned_model, x_train, y_train = self._pruned_model_setup()
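The new RunInference test pins down the sentinel semantics: a plain forward pass must not advance pruning_step, because only the training branch of the wrapper increments it. Roughly the same check outside the harness, assuming a freshly wrapped (never trained) pruned_model like the one sketched earlier:

import numpy as np
import tensorflow as tf

# Inference routes smart_cond to the no-op branch, so the step is untouched.
_ = pruned_model(np.random.rand(1, 10).astype(np.float32), training=False)
for layer in pruned_model.layers:
  if hasattr(layer, 'pruning_step'):
    print(layer.name, tf.keras.backend.get_value(layer.pruning_step))  # -1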

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 13 additions & 5 deletions
@@ -26,9 +26,8 @@
 
 # b/(139939526): update to use public API.
 from tensorflow.python.keras.utils import generic_utils
-
+from tensorflow_model_optimization.python.core.keras import compat as tf_compat
 from tensorflow_model_optimization.python.core.keras import utils
-
 from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
 from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
@@ -234,11 +233,14 @@ def call(self, inputs, training=None):
     if training is None:
       training = K.learning_phase()
 
+    def increment_step():
+      return tf_compat.assign(self.pruning_step, self.pruning_step + 1)
+
     def add_update():
       with tf.control_dependencies([
           tf.debugging.assert_greater_equal(
               self.pruning_step,
-              np.int64(0),
+              np.int64(1),
              message=self._PRUNE_CALLBACK_ERROR_MSG)
       ]):
         with tf.control_dependencies(
@@ -248,8 +250,14 @@ def add_update():
     def no_op():
       return tf.no_op('no_update')
 
-    update_op = utils.smart_cond(training, add_update, no_op)
-    self.add_update(update_op)
+    # Increment the 'pruning_step' after each step.
+    update_pruning_step = utils.smart_cond(training, increment_step, no_op)
+    self.add_update(update_pruning_step)
+
+    # Update mask tensor after each 'pruning_frequency' steps.
+    update_mask = utils.smart_cond(training, add_update, no_op)
+    self.add_update(update_mask)
+
     # Always execute the op that performs weights = weights * mask
     # Relies on UpdatePruningStep callback to ensure the weights
     # are sparse after the final backpropagation.
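Ordering makes the stricter assertion safe: the step-increment update is registered before the mask update, so when add_update's assert runs on the first training step the value is already 1 (the callback seeds 0, the wrapper adds 1). Below is a standalone sketch of the same conditional-update pattern in plain TensorFlow; the variable and helper names are illustrative, not the wrapper's internals:

import tensorflow as tf

pruning_step = tf.Variable(-1, dtype=tf.int64)  # -1: fresh, untrained model

# UpdatePruningStep seeds the sentinel to 0 on train begin...
pruning_step.assign(0)

def increment_step():
  # ...and the wrapper increments it each training step, mirroring
  # tf_compat.assign(self.pruning_step, self.pruning_step + 1).
  return pruning_step.assign_add(1)

def no_op():
  return tf.constant(0, dtype=tf.int64)

training = tf.constant(True)
tf.cond(training, increment_step, no_op)  # inference takes the no-op branch
print(int(pruning_step.numpy()))  # 1: satisfies the pruning_step >= 1 assert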
