Skip to content

Commit f87ae9d

Browse files
Introduce pruning policy instance to control what (what layers) should be pruned in the model
PiperOrigin-RevId: 373008888
1 parent d942a15 commit f87ae9d

File tree

4 files changed

+546
-1
lines changed

4 files changed

+546
-1
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ py_strict_library(
5555
],
5656
)
5757

58+
py_strict_library(
59+
name = "pruning_policy",
60+
srcs = ["pruning_policy.py"],
61+
srcs_version = "PY3",
62+
visibility = ["//visibility:public"],
63+
deps = [
64+
# tensorflow dep1,
65+
],
66+
)
67+
5868
py_strict_library(
5969
name = "pruning_schedule",
6070
srcs = ["pruning_schedule.py"],
@@ -300,3 +310,18 @@ py_strict_test(
300310
"//tensorflow_model_optimization/python/core/keras:compat",
301311
],
302312
)
313+
314+
py_test(
315+
name = "pruning_policy_test",
316+
size = "medium",
317+
srcs = ["pruning_policy_test.py"],
318+
python_version = "PY3",
319+
visibility = ["//visibility:public"],
320+
deps = [
321+
":prune",
322+
":pruning_policy",
323+
":pruning_schedule",
324+
":pruning_wrapper",
325+
# tensorflow dep1,
326+
],
327+
)

tensorflow_model_optimization/python/core/sparsity/keras/prune.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def prune_low_magnitude(to_prune,
5858
pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
5959
block_size=(1, 1),
6060
block_pooling_type='AVG',
61+
pruning_policy=None,
6162
**kwargs):
6263
"""Modify a tf.keras layer or model to be pruned during training.
6364
@@ -133,6 +134,9 @@ def prune_low_magnitude(to_prune,
133134
sparse pattern in rank-2 weight tensors.
134135
block_pooling_type: (optional) The function to use to pool weights in the
135136
block. Must be 'AVG' or 'MAX'.
137+
pruning_policy: (optional) The object that controls to which layers
138+
`PruneLowMagnitude` wrapper will be applied. This API is experimental
139+
and is subject to change.
136140
**kwargs: Additional keyword arguments to be passed to the keras layer.
137141
Ignored when to_prune is not a keras layer.
138142
@@ -173,7 +177,10 @@ def _add_pruning_wrapper(layer):
173177
layer, input_tensors=None, clone_function=_add_pruning_wrapper)
174178
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
175179
return layer
176-
return pruning_wrapper.PruneLowMagnitude(layer, **params)
180+
if pruning_policy and not pruning_policy.allow_pruning(layer):
181+
return layer
182+
else:
183+
return pruning_wrapper.PruneLowMagnitude(layer, **params)
177184

178185
params = {
179186
'pruning_schedule': pruning_schedule,
@@ -192,6 +199,8 @@ def _add_pruning_wrapper(layer):
192199
if isinstance(to_prune, list):
193200
return _prune_list(to_prune, **params)
194201
elif is_sequential_or_functional:
202+
if pruning_policy:
203+
pruning_policy.ensure_model_supports_pruning(to_prune)
195204
return _add_pruning_wrapper(to_prune)
196205
elif is_keras_layer:
197206
params.update(kwargs)
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# pylint: disable=protected-access
16+
"""Pruning Policy classes to control application of pruning wrapper."""
17+
18+
import abc
19+
import tensorflow as tf
20+
21+
layers = tf.keras.layers
22+
activations = tf.keras.activations
23+
24+
25+
class PruningPolicy(abc.ABC):
26+
"""Specifies what layers to prune in the model.
27+
28+
PruningPolicy controls application of `PruneLowMagnitude` wrapper on per-layer
29+
basis and checks that the model contains only supported layers.
30+
PruningPolicy works together with `prune_low_magnitude` through which it
31+
provides fine-grained control over pruning in the model.
32+
33+
```python
34+
pruning_params = {
35+
'pruning_schedule': ConstantSparsity(0.5, 0),
36+
'block_size': (1, 1),
37+
'block_pooling_type': 'AVG'
38+
}
39+
40+
model = prune_low_magnitude(
41+
keras.Sequential([
42+
layers.Dense(10, activation='relu', input_shape=(100,)),
43+
layers.Dense(2, activation='sigmoid')
44+
]),
45+
pruning_policy=PruneForLatencyOnXNNPack(),
46+
**pruning_params)
47+
```
48+
49+
You can inherit this class to write your own custom pruning policy.
50+
"""
51+
52+
@abc.abstractmethod
53+
def allow_pruning(self, layer):
54+
"""Checks if pruning wrapper should be applied for the current layer.
55+
56+
Args:
57+
layer: Current layer in the model.
58+
59+
Returns:
60+
True/False, whether the pruning wrapper should be applied for the layer.
61+
"""
62+
raise NotImplementedError
63+
64+
@abc.abstractmethod
65+
def ensure_model_supports_pruning(self, model):
66+
"""Checks that the model contains only supported layers.
67+
68+
Args:
69+
model: A `tf.keras.Model` instance which is going to be pruned.
70+
71+
Raises:
72+
ValueError: if the keras model doesn't support pruning policy, i.e. keras
73+
model contains an unsupported layer.
74+
"""
75+
raise NotImplementedError
76+
77+
78+
class PruneForLatencyOnXNNPack(PruningPolicy):
79+
"""Specifies to prune only 1x1 Conv2D layers in the model.
80+
81+
PruneForLatencyOnXNNPack checks that the model contains a subgraph that can
82+
leverage XNNPACK's sparse inference and applies pruning wrapper only to
83+
Conv2D with `kernel_size = (1, 1)`.
84+
85+
Reference:
86+
- [Fast Sparse ConvNets](https://arxiv.org/abs/1911.09723)
87+
- [XNNPACK Sparse Inference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#sparse-inference) # pylint: disable=line-too-long
88+
"""
89+
90+
def allow_pruning(self, layer):
91+
"""Allows to prune only 1x1 Conv2D layers."""
92+
return isinstance(layer, layers.Conv2D) and layer.kernel_size == (1, 1)
93+
94+
def _get_producers(self, layer):
95+
producers = []
96+
for node in layer._inbound_nodes:
97+
if isinstance(node.inbound_layers, list):
98+
producers.extend(node.inbound_layers)
99+
else:
100+
producers.append(node.inbound_layers)
101+
return producers
102+
103+
def _get_consumers(self, layer):
104+
return [node.outbound_layer for node in layer._outbound_nodes]
105+
106+
def _lookup_layers(self, source_layers, stop_fn, next_fn):
107+
"""Traverses the model and returns layers satisfying `stop_fn` criteria."""
108+
to_visit = set(source_layers)
109+
used_layers = set(source_layers)
110+
found_layers = set()
111+
while to_visit:
112+
layer = to_visit.pop()
113+
if stop_fn(layer):
114+
found_layers.add(layer)
115+
else:
116+
next_layers = next_fn(layer)
117+
if not next_layers:
118+
return set()
119+
for next_layer in next_layers:
120+
if next_layer not in used_layers:
121+
used_layers.add(next_layer)
122+
to_visit.add(next_layer)
123+
124+
return found_layers
125+
126+
def _start_layer_stop_fn(self, layer):
127+
"""Determines whether the layer starts a subgraph of sparse inference."""
128+
return (isinstance(layer, layers.Conv2D) and hasattr(layer, 'kernel') and
129+
layer.kernel.shape[:3] == (3, 3, 3) and layer.strides == (2, 2) and
130+
layer.padding.lower() == 'valid')
131+
132+
def _end_layer_stop_fn(self, layer):
133+
"""Determines whether the layer ends a subgraph of sparse inference."""
134+
return isinstance(layer, layers.GlobalAveragePooling2D) and layer.keepdims
135+
136+
def _check_layer_support(self, layer):
137+
"""Returns whether the layer is supported or not.
138+
139+
Mimics XNNPACK's behaviour of compatibility function.
140+
141+
Args:
142+
layer: Current layer in the model.
143+
144+
Returns:
145+
True if the layer is supported, False otherwise.
146+
147+
References:
148+
- https://github.com/google/XNNPACK/blob/master/src/subgraph.c#L130
149+
"""
150+
if isinstance(layer, (layers.Add, layers.Multiply, layers.ZeroPadding2D,
151+
layers.ReLU, layers.LeakyReLU, layers.ELU)):
152+
return True
153+
elif isinstance(layer, layers.DepthwiseConv2D):
154+
# 3x3 stride-1 convolution (no dilation, padding 1 on each side).
155+
# 3x3 stride-2 convolution (no dilation, padding 1 on each side).
156+
# 5x5 stride-1 convolution (no dilation, padding 2 on each side).
157+
# 5x5 stride-2 convolution (no dilation, padding 2 on each side).
158+
return (layer.depth_multiplier == 1 and layer.dilation_rate == (1, 1) and
159+
(layer.kernel_size == (3, 3) or layer.kernel_size == (5, 5)) and
160+
((layer.padding.lower() == 'same' and layer.strides == (1, 1)) or
161+
(layer.padding.lower() == 'valid' and layer.strides == (2, 2))))
162+
elif isinstance(layer, layers.Conv2D):
163+
# 1x1 convolution (no stride, no dilation, no padding, no groups).
164+
return (layer.groups == 1 and layer.dilation_rate == (1, 1) and
165+
layer.kernel_size == (1, 1) and layer.strides == (1, 1))
166+
elif isinstance(layer, layers.GlobalAveragePooling2D):
167+
return layer.keepdims
168+
elif isinstance(layer, layers.BatchNormalization):
169+
return list(layer.axis) == [3]
170+
elif isinstance(layer, layers.UpSampling2D):
171+
return layer.interpolation == 'bilinear'
172+
elif isinstance(layer, layers.Activation):
173+
return activations.serialize(layer.activation) in ('relu', 'relu6',
174+
'leaky_relu', 'elu',
175+
'sigmoid')
176+
return False
177+
178+
def ensure_model_supports_pruning(self, model):
179+
"""Ensures that the model contains only supported layers."""
180+
181+
# Check whether the model is a subclass model.
182+
if (not model._is_graph_network and
183+
not isinstance(model, tf.keras.models.Sequential)):
184+
raise ValueError('Subclassed models are not supported currently.')
185+
186+
if not model.built:
187+
raise ValueError('Unbuilt models are not supported currently.')
188+
189+
# Gather the layers that consume model's input tensors.
190+
input_layers = set(inp._keras_history.layer for inp in model.inputs)
191+
192+
# Search for the start layer (Conv2D 3x3, `stride = (2, 2)`,
193+
# `filters = 3`, `padding = `VALID``) in every input branch (forward).
194+
start_layers = self._lookup_layers(
195+
input_layers,
196+
self._start_layer_stop_fn,
197+
self._get_consumers,
198+
)
199+
if not start_layers:
200+
raise ValueError(('Could not find `Conv2D 3x3` layer with stride 2x2, '
201+
'`input filters == 3` and `VALID` padding in all input '
202+
'branches of the model'))
203+
204+
# Search for the end layer (GlobalAveragePooling with `keepdims = True`)
205+
# for every output branch (backward).
206+
output_layers = set(inp._keras_history.layer for inp in model.outputs)
207+
end_layers = self._lookup_layers(
208+
output_layers,
209+
self._end_layer_stop_fn,
210+
self._get_producers,
211+
)
212+
if not end_layers:
213+
raise ValueError(('Could not find a `GlobalAveragePooling2D` layer with '
214+
'`keepdims = True` in all output branches'))
215+
216+
# Ensure that all layers between the start and the end layers are supported
217+
# for pruning.
218+
def visit_fn(layer):
219+
if layer not in end_layers and not self._check_layer_support(layer):
220+
raise ValueError(('Layer {layer} is not supported for the {policy} '
221+
'pruning policy'.format(
222+
layer=layer.__class__.__name__,
223+
policy=self.__class__.__name__)))
224+
return layer in end_layers
225+
226+
_ = self._lookup_layers(
227+
sum([self._get_consumers(layer) for layer in start_layers], []),
228+
visit_fn,
229+
self._get_consumers,
230+
)

0 commit comments

Comments
 (0)