Skip to content

Commit a3ae06a

Browse files
Add basic TF ops to the supported list of PruneForLatencyOnXNNPack policy
PiperOrigin-RevId: 376909045
1 parent 8211578 commit a3ae06a

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def _check_layer_support(self, layer):
217217
return activations.serialize(layer.activation) in ('relu', 'relu6',
218218
'leaky_relu', 'elu',
219219
'sigmoid')
220+
elif layer.__class__.__name__ == 'TFOpLambda':
221+
return layer.function in (tf.identity, tf.__operators__.add, tf.math.add,
222+
tf.math.subtract, tf.math.multiply)
220223
return False
221224

222225
def ensure_model_supports_pruning(self, model):

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,30 @@ def testPruneFunctionalModelAfterCloneForLatencyOnXNNPackPolicy(self):
301301
)
302302
self.assertEqual(self._count_pruned_layers(pruned_model), 1)
303303

304+
def testFunctionalModelWithTFOpsForLatencyOnXNNPackPolicy(self):
305+
i = keras.Input(shape=(8, 8, 3))
306+
x = layers.ZeroPadding2D(padding=1)(i)
307+
x = layers.Conv2D(
308+
filters=16,
309+
kernel_size=(3, 3),
310+
strides=(2, 2),
311+
padding='valid',
312+
)(x)
313+
residual = layers.Conv2D(filters=16, kernel_size=[1, 1])(x)
314+
x = x + residual
315+
x = x - residual
316+
x = x * residual
317+
x = tf.identity(x)
318+
o = layers.GlobalAveragePooling2D(keepdims=True)(x)
319+
model = keras.Model(inputs=[i], outputs=[o])
320+
321+
pruned_model = prune.prune_low_magnitude(
322+
model,
323+
pruning_policy=pruning_policy.PruneForLatencyOnXNNPack(),
324+
**self.params,
325+
)
326+
self.assertEqual(self._count_pruned_layers(pruned_model), 1)
327+
304328

305329
if __name__ == '__main__':
306330
tf.test.main()

0 commit comments

Comments
 (0)