Skip to content

Commit 6be78ae

Browse files
Fix the pruning policy to handle models after cloning
PiperOrigin-RevId: 375601956
1 parent 137850e commit 6be78ae

File tree

2 files changed: +28 additions, −5 deletions

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ def allow_pruning(self, layer):
9494
return isinstance(layer, layers.Conv2D) and layer.kernel_size == (1, 1)
9595

9696
def _get_producers(self, layer):
97-
producers = []
97+
producers = set()
9898
for node in layer._inbound_nodes:
9999
if isinstance(node.inbound_layers, list):
100-
producers.extend(node.inbound_layers)
100+
producers.update(node.inbound_layers)
101101
else:
102-
producers.append(node.inbound_layers)
102+
producers.add(node.inbound_layers)
103103
return producers
104104

105105
def _get_consumers(self, layer):
@@ -133,7 +133,7 @@ def _lookup_layers(self, source_layers, stop_fn, next_fn):
133133
def _start_layer_stop_fn(self, layer):
134134
"""Determines whether the layer starts a subgraph of sparse inference."""
135135
if isinstance(layer, layers.Conv2D):
136-
producers = self._get_producers(layer)
136+
producers = list(self._get_producers(layer))
137137
return (hasattr(layer, 'kernel') and
138138
layer.kernel.shape[:3] == (3, 3, 3) and
139139
layer.strides == (2, 2) and layer.padding.lower() == 'valid' and
@@ -172,7 +172,7 @@ def _check_layer_support(self, layer):
172172
# 5x5 convolution with `VALID` padding (no dilation, stride-1 or stride-2,
173173
# preceding `ZeroPadding2D` layer with padding 2 on each side.
174174
padding = layer.padding.lower()
175-
producers = self._get_producers(layer)
175+
producers = list(self._get_producers(layer))
176176
zero_padding = (
177177
producers[0] if len(producers) == 1 and
178178
isinstance(producers[0], layers.ZeroPadding2D) else None)

tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,29 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
278278
)
279279
self.assertEqual(self._count_pruned_layers(pruned_model), 6)
280280

281+
def testPruneFunctionalModelAfterCloneForLatencyOnXNNPackPolicy(self):
282+
i = keras.Input(shape=(8, 8, 3))
283+
x = layers.ZeroPadding2D(padding=1)(i)
284+
x = layers.Conv2D(
285+
filters=16,
286+
kernel_size=(3, 3),
287+
strides=(2, 2),
288+
padding='valid',
289+
)(
290+
x)
291+
x = layers.Conv2D(filters=16, kernel_size=[1, 1])(x)
292+
o = layers.GlobalAveragePooling2D(keepdims=True)(x)
293+
original_model = keras.Model(inputs=[i], outputs=[o])
294+
295+
cloned_model = tf.keras.models.clone_model(
296+
original_model, clone_function=lambda l: l)
297+
pruned_model = prune.prune_low_magnitude(
298+
cloned_model,
299+
pruning_policy=pruning_policy.PruneForLatencyOnXNNPack(),
300+
**self.params,
301+
)
302+
self.assertEqual(self._count_pruned_layers(pruned_model), 1)
303+
281304

282305
if __name__ == '__main__':
283306
tf.test.main()

0 commit comments

Comments (0)