Skip to content

Commit d942a15

Browse files
fredrectensorflower-gardener
authored andcommitted
Add API and command line tool to prune Keras models, without retraining.
Quickly produces pruned models, without concern for accuracy. Useful to evaluate the performance benefits of given pruning parameters, without time-consuming retraining. PiperOrigin-RevId: 372284927
1 parent 1eaa343 commit d942a15

File tree

7 files changed

+377
-15
lines changed

7 files changed

+377
-15
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,6 @@
3030
callbacks = tf.keras.callbacks
3131

3232

33-
def _collect_prunable_layers(model):
34-
"""Recursively collect the prunable layers in the model."""
35-
prunable_layers = []
36-
for layer in model.layers:
37-
# A keras model may have other models as layers.
38-
if isinstance(layer, tf.keras.Model):
39-
prunable_layers += _collect_prunable_layers(layer)
40-
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
41-
prunable_layers.append(layer)
42-
43-
return prunable_layers
44-
45-
4633
class UpdatePruningStep(callbacks.Callback):
4734
"""Keras callback which updates pruning wrappers with the optimizer step.
4835
@@ -63,7 +50,7 @@ def __init__(self):
6350

6451
def on_train_begin(self, logs=None):
6552
# Collect all the prunable layers in the model.
66-
self.prunable_layers = _collect_prunable_layers(self.model)
53+
self.prunable_layers = pruning_wrapper.collect_prunable_layers(self.model)
6754
if not self.prunable_layers:
6855
return
6956
# If the model is newly created/initialized, set the 'pruning_step' to 0.
@@ -125,7 +112,7 @@ def on_epoch_begin(self, epoch, logs=None):
125112

126113
pruning_logs = {}
127114
params = []
128-
prunable_layers = _collect_prunable_layers(self.model)
115+
prunable_layers = pruning_wrapper.collect_prunable_layers(self.model)
129116
for layer in prunable_layers:
130117
for _, mask, threshold in layer.pruning_vars:
131118
params.append(mask)

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,11 @@ def get_weights(self):
350350

351351
def set_weights(self, weights):
352352
self.layer.set_weights(weights)
353+
354+
355+
def collect_prunable_layers(model):
356+
"""Recursively collect the prunable layers in the model."""
357+
return [
358+
layer for layer in model.submodules
359+
if isinstance(layer, PruneLowMagnitude)
360+
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
load("//tensorflow_model_optimization:tensorflow_model_optimization.bzl", "py_strict_library")
2+
3+
package(default_visibility = [
4+
"//tensorflow_model_optimization:__subpackages__",
5+
])
6+
7+
licenses(["notice"])
8+
9+
py_strict_library(
10+
name = "sparsity_tooling",
11+
srcs = ["sparsity_tooling.py"],
12+
srcs_version = "PY3",
13+
visibility = ["//visibility:public"],
14+
deps = [
15+
# tensorflow dep1,
16+
"//tensorflow_model_optimization/python/core/sparsity/keras:prune",
17+
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
18+
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_wrapper",
19+
],
20+
)
21+
22+
py_test(
23+
name = "sparsity_tooling_test",
24+
size = "medium",
25+
srcs = ["sparsity_tooling_test.py"],
26+
python_version = "PY3",
27+
visibility = ["//visibility:public"],
28+
deps = [
29+
":sparsity_tooling",
30+
# absl/testing:parameterized dep1,
31+
# tensorflow dep1,
32+
"//tensorflow_model_optimization/python/core/keras:compat",
33+
"//tensorflow_model_optimization/python/core/sparsity/keras:test_utils",
34+
],
35+
)
36+
37+
py_binary(
38+
name = "evaluate_pruning",
39+
srcs = ["evaluate_pruning.py"],
40+
python_version = "PY3",
41+
deps = [
42+
":sparsity_tooling",
43+
# tensorflow dep1,
44+
"//tensorflow_model_optimization/python/core/sparsity/keras:prune",
45+
],
46+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tool to quickly prune a Keras model for evaluation purpose.
16+
17+
Prunes the model with the given spasity parameters, without retraining. Will
18+
output a converted TFLite model for both pruned and unpruned versions.
19+
20+
This tool is intented to produce sparsified models for evaluating the
21+
performance benefits (model size, inference time, …) of pruning. Since the
22+
sparsity is applied in one shot, without retrainig, the accuracy of the
23+
resulting model will be severly degraded.
24+
"""
25+
26+
from __future__ import print_function
27+
28+
import os
29+
import tempfile
30+
import textwrap
31+
import zipfile
32+
33+
from absl import app
34+
from absl import flags
35+
import tensorflow as tf
36+
37+
from tensorflow_model_optimization.python.core.sparsity.keras import prune
38+
from tensorflow_model_optimization.python.core.sparsity.keras.tools import sparsity_tooling
39+
40+
41+
_MODEL_PATH = flags.DEFINE_string('model', None, 'Keras model file to prune')
42+
_OUTPUT_DIR = flags.DEFINE_string('output_dir', None, 'Output directory')
43+
_SPARSITY = flags.DEFINE_float(
44+
'sparsity',
45+
0.8,
46+
'Target sparsity level, as float in [0,1] interval',
47+
lower_bound=0,
48+
upper_bound=1)
49+
_BLOCK_SIZE = flags.DEFINE_string(
50+
'block_size', '1,1',
51+
'Comma-separated dimensions (height,weight) of the block sparsity pattern.'
52+
)
53+
54+
55+
def _parse_block_size_flag(value):
56+
height_str, weight_str = value.split(',')
57+
return int(height_str), int(weight_str)
58+
59+
60+
@flags.validator(_BLOCK_SIZE.name)
61+
def _check_block_size(flag_value):
62+
try:
63+
_parse_block_size_flag(flag_value)
64+
return True
65+
except:
66+
raise flags.ValidationError('Invalid block size value "%s".' % flag_value)
67+
68+
69+
def convert_to_tflite(keras_model, output_path):
70+
"""Converts the given Keras model to TFLite and write it to a file."""
71+
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
72+
converter.optimizations = {tf.lite.Optimize.EXPERIMENTAL_SPARSITY}
73+
74+
with open(output_path, 'wb') as out:
75+
out.write(converter.convert())
76+
77+
78+
def get_gzipped_size(model_path):
79+
"""Measures the compressed size of a model."""
80+
with tempfile.TemporaryFile(suffix='.zip') as zipped_file:
81+
with zipfile.ZipFile(
82+
zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
83+
f.write(model_path)
84+
85+
zipped_file.seek(0, 2)
86+
return os.fstat(zipped_file.fileno()).st_size
87+
88+
89+
def pruned_model_filename(sparsity, block_size):
90+
"""Produces a human-readable name including sparsity parameters."""
91+
return 'pruned_model_sparsity_%.2f_block_%s.tflite' % (
92+
sparsity, '%dx%d' % block_size)
93+
94+
95+
def run(input_model_path, output_dir, target_sparsity, block_size):
96+
"""Prunes the model and converts both pruned and unpruned versions to TFLite."""
97+
98+
print(textwrap.dedent("""\
99+
Warning: The sparse models produced by this tool have poor accuracy. They
100+
are not intended to be served in production, but to be used for
101+
performance benchmarking."""))
102+
103+
input_model = tf.keras.models.load_model(input_model_path)
104+
105+
os.makedirs(output_dir, exist_ok=True)
106+
unpruned_tflite_path = os.path.join(
107+
output_dir, 'unpruned_model.tflite')
108+
pruned_tflite_path = os.path.join(
109+
output_dir, pruned_model_filename(target_sparsity, block_size))
110+
111+
# Convert to TFLite without pruning
112+
convert_to_tflite(input_model, unpruned_tflite_path)
113+
114+
# Prune and convert to TFLite
115+
pruned_model = sparsity_tooling.prune_for_benchmark(
116+
keras_model=input_model,
117+
target_sparsity=target_sparsity,
118+
block_size=block_size)
119+
stripped_model = prune.strip_pruning(pruned_model) # Remove pruning wrapper
120+
convert_to_tflite(stripped_model, pruned_tflite_path)
121+
122+
# Measure the compressed size of unpruned vs pruned TFLite models
123+
unpruned_compressed_size = get_gzipped_size(unpruned_tflite_path)
124+
pruned_compressed_size = get_gzipped_size(pruned_tflite_path)
125+
print('Size of gzipped TFLite models:')
126+
print(' * Unpruned : %.2fMiB' % (unpruned_compressed_size / (2.**20)))
127+
print(' * Pruned : %.2fMiB' % (pruned_compressed_size / (2.**20)))
128+
print(' diff : %d%%' %
129+
(100. * (pruned_compressed_size - unpruned_compressed_size) /
130+
unpruned_compressed_size))
131+
132+
133+
def main(argv):
134+
if len(argv) > 1:
135+
raise app.UsageError('Too many command-line arguments.')
136+
137+
block_size = _parse_block_size_flag(_BLOCK_SIZE.value)
138+
run(_MODEL_PATH.value, _OUTPUT_DIR.value, _SPARSITY.value, block_size)
139+
140+
141+
if __name__ == '__main__':
142+
flags.mark_flag_as_required(_MODEL_PATH.name)
143+
flags.mark_flag_as_required(_OUTPUT_DIR.name)
144+
145+
app.run(main)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utilities to prune without training.
16+
17+
Quickly produces pruned models, with no concern for accuracy. Useful to
18+
evaluate the performance benefits of given pruning parameters, without
19+
time-consuming retraining.
20+
"""
21+
22+
from __future__ import absolute_import
23+
from __future__ import division
24+
from __future__ import print_function
25+
26+
import tensorflow as tf
27+
28+
from tensorflow_model_optimization.python.core.sparsity.keras import prune
29+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
30+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
31+
32+
keras = tf.keras
33+
34+
35+
class StepIndependentConstantSparsity(pruning_schedule.PruningSchedule):
36+
"""Pruning schedule with constant sparsity, applied at any step."""
37+
38+
def __init__(self, target_sparsity):
39+
"""Initializes a Pruning schedule with constant sparsity.
40+
41+
Sparsity is applied at every step.
42+
43+
Args:
44+
target_sparsity: Target sparsity as float, in [0, 1] interval.
45+
"""
46+
self.target_sparsity = target_sparsity
47+
48+
def __call__(self, step):
49+
return (True, tf.constant(self.target_sparsity, dtype=tf.float32))
50+
51+
def get_config(self):
52+
return {
53+
'class_name': self.__class__.__name__,
54+
'config': {
55+
'target_sparsity': self.target_sparsity,
56+
}
57+
}
58+
59+
60+
def _apply_pruning(prunable_object):
61+
"""Calculates the masks and updates weights of layers of a wrapped model."""
62+
assert tf.executing_eagerly()
63+
for layer in pruning_wrapper.collect_prunable_layers(prunable_object):
64+
layer.pruning_obj.conditional_mask_update() # Create mask
65+
layer.pruning_obj.weight_mask_op() # weight = weight * mask
66+
67+
68+
def prune_for_benchmark(keras_model,
69+
target_sparsity,
70+
block_size=(1, 1)):
71+
"""Prunes a tf.keras model in a single step, without re-training.
72+
73+
This function is intented to quickly apply sparsity to a model, without
74+
consideration for accuracy.
75+
76+
Args:
77+
keras_model: A `tf.keras.Model` instance.
78+
target_sparsity: Target sparsity as float, in [0, 1] interval.
79+
block_size: The dimensions (height, weight) for the block sparse
80+
pattern in rank-2 weight tensors.
81+
Returns:
82+
A pruned model, modified with pruning wrappers.
83+
"""
84+
85+
pruning_params = {
86+
'pruning_schedule': StepIndependentConstantSparsity(target_sparsity),
87+
'block_size': block_size,
88+
}
89+
90+
prunable_object = prune.prune_low_magnitude(keras_model, **pruning_params)
91+
_apply_pruning(prunable_object)
92+
93+
return prunable_object

0 commit comments

Comments
 (0)