Skip to content

Commit 65bcf47

Browse files
rino20tensorflower-gardener
authored andcommitted
Update pruning api examples to use PolynomialDecay instead of ConstantSparsity.
PiperOrigin-RevId: 402817279
1 parent d53ff08 commit 65bcf47

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/estimator_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
"""Utility functions for making pruning wrapper work with estimators."""
1616

17-
import tensorflow as tf
17+
import tensorflow.compat.v1 as tf
1818

1919
from tensorflow.python.framework import ops
2020
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import PruneLowMagnitude
@@ -32,6 +32,12 @@ def __new__(cls, model, step=None, train_op=None, **kwargs):
3232
raise ValueError(
3333
"Must provide train_op for creating a PruningEstimatorSpec")
3434

35+
for layer in model.layers:
36+
# If the model is newly created/initialized, set the 'pruning_step' to 0.
37+
# Otherwise, do nothing.
38+
if isinstance(layer, PruneLowMagnitude) and layer.pruning_step == -1:
39+
tf.assign(layer.pruning_step, 0)
40+
3541
def _get_step_increment_ops(model, step=None):
3642
"""Returns ops to increment the pruning_step in the prunable layers."""
3743
increment_ops = []
@@ -43,7 +49,7 @@ def _get_step_increment_ops(model, step=None):
4349
increment_ops.append(tf.assign_add(layer.pruning_step, 1))
4450
else:
4551
increment_ops.append(
46-
tf.assign(layer.pruning_step, tf.cast(step, tf.int32)))
52+
tf.assign(layer.pruning_step, tf.cast(step, tf.int64)))
4753

4854
return tf.group(increment_ops)
4955

tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_cnn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
2626
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
2727

28-
ConstantSparsity = pruning_schedule.ConstantSparsity
28+
PolynomialDecay = pruning_schedule.PolynomialDecay
2929
keras = tf.keras
3030
l = keras.layers
3131

@@ -159,7 +159,13 @@ def main(unused_argv):
159159
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
160160

161161
pruning_params = {
162-
'pruning_schedule': ConstantSparsity(0.75, begin_step=2000, frequency=100)
162+
'pruning_schedule':
163+
PolynomialDecay(
164+
initial_sparsity=0.1,
165+
final_sparsity=0.75,
166+
begin_step=1000,
167+
end_step=5000,
168+
frequency=100)
163169
}
164170

165171
layerwise_model = build_layerwise_model(input_shape, **pruning_params)

0 commit comments

Comments
 (0)