Skip to content

Commit d630127

Browse files
1. The layer is in PruneRegistry (most keras built-in layers).
2. The layer is a `PrunableLayer` instance. 3. Newly added. The layer has a customer defined `get_prunable_weights` method. PiperOrigin-RevId: 355510603
1 parent ec778f9 commit d630127

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ class PruneLowMagnitude(Wrapper):
6767
Custom keras layers:
6868
The pruning wrapper can also be applied to a user-defined keras layer.
6969
Such a layer may contain one or more weight tensors that may be pruned.
70-
To apply pruning wrapper to such layers, set prunable_weight_names to mark
71-
the weight tensors for pruning.
70+
To apply pruning wrapper to such layers, the layer should be a `PrunableLayer`
71+
instance or, more directly, user should define a `get_prunable_weights` method
72+
for the layer (Check the pruning_wrapper_test.CustomLayerPrunable for more
73+
details about how to define a user-defined prunable layer).
7274
7375
Sparsity function:
7476
The target sparsity for the weight tensors are set through the
@@ -141,7 +143,8 @@ def __init__(self,
141143
kwargs.update({'name': '{}_{}'.format(
142144
generic_utils.to_snake_case(self.__class__.__name__), layer.name)})
143145

144-
if isinstance(layer, prunable_layer.PrunableLayer):
146+
if isinstance(layer, prunable_layer.PrunableLayer) or hasattr(
147+
layer, 'get_prunable_weights'):
145148
# Custom layer in client code which supports pruning.
146149
super(PruneLowMagnitude, self).__init__(layer, **kwargs)
147150
elif prune_registry.PruneRegistry.supports(layer):
@@ -151,8 +154,10 @@ def __init__(self,
151154
else:
152155
raise ValueError(
153156
'Please initialize `Prune` with a supported layer. Layers should '
154-
'either be a `PrunableLayer` instance, or should be supported by the '
155-
'PruneRegistry. You passed: {input}'.format(input=layer.__class__))
157+
'either be supported by the PruneRegistry (built-in keras layers) or '
158+
'should be a `PrunableLayer` instance, or should has a customer '
159+
'defined `get_prunable_weights` method. You passed: '
160+
'{input}'.format(input=layer.__class__))
156161

157162
self._track_trackable(layer, name='layer')
158163

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,35 @@
2828
Prune = pruning_wrapper.PruneLowMagnitude
2929

3030

31+
class CustomLayer(keras.layers.Layer):
32+
"""A custom layer which is not prunable."""
33+
34+
def __init__(self, input_dim=16, output_dim=32):
35+
super(CustomLayer, self).__init__()
36+
self.weight = self.add_weight(
37+
shape=(input_dim, output_dim),
38+
initializer='random_normal',
39+
trainable=True)
40+
self.bias = self.add_weight(
41+
shape=(output_dim),
42+
initializer='zeros',
43+
trainable=True)
44+
45+
def call(self, inputs):
46+
return tf.matmul(inputs, self.weight) + self.bias
47+
48+
49+
class CustomLayerPrunable(CustomLayer):
50+
"""A prunable custom layer.
51+
52+
The layer is same with the CustomLayer except it has a 'get_prunable_weights'
53+
attribute.
54+
"""
55+
56+
def get_prunable_weights(self):
57+
return [self.weight, self.bias]
58+
59+
3160
class PruningWrapperTest(tf.test.TestCase):
3261

3362
def setUp(self):
@@ -100,6 +129,19 @@ def testPruneModel(self):
100129
'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude
101130
}).get_config())
102131

132+
def testCustomLayerNonPrunable(self):
133+
layer = CustomLayer(input_dim=16, output_dim=32)
134+
inputs = keras.layers.Input(shape=(16))
135+
_ = layer(inputs)
136+
with self.assertRaises(ValueError):
137+
pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MAX')
138+
139+
def testCustomLayerPrunable(self):
140+
layer = CustomLayerPrunable(input_dim=16, output_dim=32)
141+
inputs = keras.layers.Input(shape=(16))
142+
_ = layer(inputs)
143+
pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MAX')
144+
103145

104146
if __name__ == '__main__':
105147
tf.test.main()

0 commit comments

Comments
 (0)