Skip to content

Commit a438885

Browse files
Minor PruneForLatencyOnXNNPack pruning policy improvement
PiperOrigin-RevId: 375136771
1 parent 482ab3f commit a438885

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,12 @@ def _check_layer_support(self, layer):
165165
layers.ReLU, layers.LeakyReLU, layers.ELU, layers.Dropout)):
166166
return True
167167
elif isinstance(layer, layers.DepthwiseConv2D):
168-
# 3x3 stride-1 convolution (no dilation, padding 1 on each side).
169-
# 3x3 stride-2 convolution (no dilation, padding 1 on each side).
170-
# 5x5 stride-1 convolution (no dilation, padding 2 on each side).
171-
# 5x5 stride-2 convolution (no dilation, padding 2 on each side).
168+
# 3x3 convolution with `SAME` padding (no dilation, stride-1).
169+
# 3x3 convolution with `VALID` padding (no dilation, stride-1 or stride-2,
170+
# preceding `ZeroPadding2D` layer with padding 1 on each side.
171+
# 5x5 convolution with `SAME` padding (no dilation, stride-1)
172+
# 5x5 convolution with `VALID` padding (no dilation, stride-1 or stride-2,
173+
# preceding `ZeroPadding2D` layer with padding 2 on each side.
172174
padding = layer.padding.lower()
173175
producers = self._get_producers(layer)
174176
zero_padding = (
@@ -180,7 +182,8 @@ def _check_layer_support(self, layer):
180182
padding == 'same')
181183

182184
supported_case_2 = (
183-
layer.kernel_size == (3, 3) and layer.strides == (2, 2) and
185+
layer.kernel_size == (3, 3) and
186+
(layer.strides == (1, 1) or layer.strides == (2, 2)) and
184187
padding == 'valid' and zero_padding and
185188
zero_padding.padding == ((1, 1), (1, 1)))
186189

@@ -189,7 +192,8 @@ def _check_layer_support(self, layer):
189192
padding == 'same')
190193

191194
supported_case_4 = (
192-
layer.kernel_size == (5, 5) and layer.strides == (2, 2) and
195+
layer.kernel_size == (5, 5) and
196+
(layer.strides == (1, 1) or layer.strides == (2, 2)) and
193197
padding == 'valid' and zero_padding and
194198
zero_padding.padding == ((2, 2), (2, 2)))
195199

0 commit comments

Comments
 (0)