Skip to content

Commit 9fff2f3

Browse files
Make the PruneForLatencyOnXNNPack prune policy more robust
PiperOrigin-RevId: 373696218
1 parent: 3ad296c · commit: 9fff2f3

File tree

2 files changed: 57 additions and 17 deletions

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py

Lines changed: 48 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,12 @@ def _get_producers(self, layer):
103103
return producers
104104

105105
def _get_consumers(self, layer):
106-
return [node.outbound_layer for node in layer._outbound_nodes]
106+
107+
def unpack(layer):
108+
return (unpack(layer.layers[0])
109+
if isinstance(layer, tf.keras.Sequential) else layer)
110+
111+
return [unpack(node.outbound_layer) for node in layer._outbound_nodes]
107112

108113
def _lookup_layers(self, source_layers, stop_fn, next_fn):
109114
"""Traverses the model and returns layers satisfying `stop_fn` criteria."""
@@ -127,9 +132,15 @@ def _lookup_layers(self, source_layers, stop_fn, next_fn):
127132

128133
def _start_layer_stop_fn(self, layer):
129134
"""Determines whether the layer starts a subgraph of sparse inference."""
130-
return (isinstance(layer, layers.Conv2D) and hasattr(layer, 'kernel') and
131-
layer.kernel.shape[:3] == (3, 3, 3) and layer.strides == (2, 2) and
132-
layer.padding.lower() == 'valid')
135+
if isinstance(layer, layers.Conv2D):
136+
producers = self._get_producers(layer)
137+
return (hasattr(layer, 'kernel') and
138+
layer.kernel.shape[:3] == (3, 3, 3) and
139+
layer.strides == (2, 2) and layer.padding.lower() == 'valid' and
140+
len(producers) == 1 and
141+
isinstance(producers[0], layers.ZeroPadding2D) and
142+
producers[0].padding == ((1, 1), (1, 1)))
143+
return False
133144

134145
def _end_layer_stop_fn(self, layer):
135146
"""Determines whether the layer ends a subgraph of sparse inference."""
@@ -157,10 +168,36 @@ def _check_layer_support(self, layer):
157168
# 3x3 stride-2 convolution (no dilation, padding 1 on each side).
158169
# 5x5 stride-1 convolution (no dilation, padding 2 on each side).
159170
# 5x5 stride-2 convolution (no dilation, padding 2 on each side).
160-
return (layer.depth_multiplier == 1 and layer.dilation_rate == (1, 1) and
161-
(layer.kernel_size == (3, 3) or layer.kernel_size == (5, 5)) and
162-
((layer.padding.lower() == 'same' and layer.strides == (1, 1)) or
163-
(layer.padding.lower() == 'valid' and layer.strides == (2, 2))))
171+
padding = layer.padding.lower()
172+
producers = self._get_producers(layer)
173+
zero_padding = (
174+
producers[0] if len(producers) == 1 and
175+
isinstance(producers[0], layers.ZeroPadding2D) else None)
176+
177+
supported_case_1 = (
178+
layer.kernel_size == (3, 3) and layer.strides == (1, 1) and
179+
padding == 'same')
180+
181+
supported_case_2 = (
182+
layer.kernel_size == (3, 3) and layer.strides == (2, 2) and
183+
padding == 'valid' and zero_padding and
184+
zero_padding.padding == ((1, 1), (1, 1)))
185+
186+
supported_case_3 = (
187+
layer.kernel_size == (5, 5) and layer.strides == (1, 1) and
188+
padding == 'same')
189+
190+
supported_case_4 = (
191+
layer.kernel_size == (5, 5) and layer.strides == (2, 2) and
192+
padding == 'valid' and zero_padding and
193+
zero_padding.padding == ((2, 2), (2, 2)))
194+
195+
supported = (
196+
layer.depth_multiplier == 1 and layer.dilation_rate == (1, 1) and
197+
(supported_case_1 or supported_case_2 or supported_case_3 or
198+
supported_case_4))
199+
200+
return supported
164201
elif isinstance(layer, layers.Conv2D):
165202
# 1x1 convolution (no stride, no dilation, no padding, no groups).
166203
return (layer.groups == 1 and layer.dilation_rate == (1, 1) and
@@ -200,8 +237,9 @@ def ensure_model_supports_pruning(self, model):
200237
)
201238
if not start_layers:
202239
raise ValueError(('Could not find `Conv2D 3x3` layer with stride 2x2, '
203-
'`input filters == 3` and `VALID` padding in all input '
204-
'branches of the model'))
240+
'`input filters == 3` and `VALID` padding and '
241+
'preceding `ZeroPadding2D` with `padding == 1` in all '
242+
'input branches of the model'))
205243

206244
# Search for the end layer (GlobalAveragePooling with `keepdims = True`)
207245
# for every output branch (backward).

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727

2828
class PruningPolicyTest(tf.test.TestCase):
2929
INVALID_TO_PRUNE_START_LAYER_ERROR = (
30-
'Could not find `Conv2D 3x3` layer with stride 2x2, `input filters == 3` '
31-
'and `VALID` padding in all input branches of the model')
30+
'Could not find `Conv2D 3x3` layer with stride 2x2, `input filters == 3`'
31+
' and `VALID` padding and preceding `ZeroPadding2D` with `padding == 1` '
32+
'in all input branches of the model'
33+
)
3234

3335
INVALID_TO_PRUNE_STOP_LAYER_ERROR = (
3436
'Could not find a `GlobalAveragePooling2D` layer with `keepdims = True` '
@@ -164,18 +166,18 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
164166
self.assertEqual(self._count_pruned_layers(pruned_model), 1)
165167

166168
def testPruneModelRecursivelyForLatencyOnXNNPackPolicy(self):
167-
internal_model = keras.Sequential(
168-
[layers.ZeroPadding2D(padding=1, input_shape=(8, 8, 3))])
169169
original_model = keras.Sequential([
170-
internal_model,
170+
layers.ZeroPadding2D(padding=1, input_shape=(8, 8, 3)),
171171
layers.Conv2D(
172172
filters=4,
173173
kernel_size=(3, 3),
174174
strides=(2, 2),
175175
padding='valid',
176176
),
177-
layers.Conv2D(filters=8, kernel_size=[1, 1]),
178-
layers.Conv2D(filters=16, kernel_size=[1, 1]),
177+
keras.Sequential([
178+
layers.Conv2D(filters=8, kernel_size=[1, 1]),
179+
layers.Conv2D(filters=16, kernel_size=[1, 1]),
180+
]),
179181
layers.GlobalAveragePooling2D(keepdims=True),
180182
])
181183
pruned_model = prune.prune_low_magnitude(

0 commit comments

Comments
 (0)