Skip to content

Commit 089fadb

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Raise error in prune_low_magnitude when unsupported to_prune object is passed in.
PiperOrigin-RevId: 284044063
1 parent cc05be8 commit 089fadb

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ def prune_low_magnitude(to_prune,
5252
**kwargs):
5353
"""Modify a keras layer or model to be pruned during training.
5454
55-
This function wraps a keras model or layer with pruning functionality which
55+
This function wraps a tf.keras model or layer with pruning functionality which
5656
sparsifies the layer's weights during training. For example, using this with
5757
50% sparsity will ensure that 50% of the layer's weights are zero.
5858
5959
The function accepts either a single keras layer
60-
(subclass of `keras.layers.Layer`), list of keras layers or a keras model
61-
(instance of `keras.models.Model`) and handles them appropriately.
60+
(subclass of `tf.keras.layers.Layer`), list of keras layers or a Sequential
61+
or Functional keras model and handles them appropriately.
6262
6363
If it encounters a layer it does not know how to handle, it will throw an
6464
error. While pruning an entire model, even a single unknown layer would lead
@@ -144,15 +144,28 @@ def _add_pruning_wrapper(layer):
144144
'block_size': block_size,
145145
'block_pooling_type': block_pooling_type
146146
}
147+
is_sequential_or_functional = isinstance(
148+
to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential) or
149+
to_prune._is_graph_network)
150+
151+
# A subclassed model is also a subclass of keras.layers.Layer.
152+
is_keras_layer = isinstance(
153+
to_prune, keras.layers.Layer) and not isinstance(to_prune, keras.Model)
147154

148155
if isinstance(to_prune, list):
149156
return _prune_list(to_prune, **params)
150-
elif isinstance(to_prune, keras.Model):
157+
elif is_sequential_or_functional:
151158
return keras.models.clone_model(
152159
to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
153-
elif isinstance(to_prune, keras.layers.Layer):
160+
elif is_keras_layer:
154161
params.update(kwargs)
155162
return pruning_wrapper.PruneLowMagnitude(to_prune, **params)
163+
else:
164+
raise ValueError(
165+
'`prune_low_magnitude` can only prune an object of the following '
166+
'types: tf.keras.models.Sequential, tf.keras functional model, '
167+
'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
168+
'an object of type: {input}.'.format(input=to_prune.__class__.__name__))
156169

157170

158171
def strip_pruning(model):

tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
3232

3333

34-
class TestModel(keras.Model):
34+
class TestSubclassedModel(keras.Model):
3535
"""A model subclass."""
3636

3737
def __init__(self):
3838
"""A test subclass model with one dense layer."""
39-
super(TestModel, self).__init__(name='test_model')
39+
super(TestSubclassedModel, self).__init__(name='test_model')
4040
self.layer1 = keras.layers.Dense(10, activation='relu')
4141

4242
def call(self, inputs):
@@ -55,6 +55,13 @@ class CustomNonPrunableLayer(layers.Dense):
5555

5656
class PruneTest(test.TestCase, parameterized.TestCase):
5757

58+
INVALID_TO_PRUNE_PARAM_ERROR = ('`prune_low_magnitude` can only prune an '
59+
'object of the following types: '
60+
'tf.keras.models.Sequential, tf.keras '
61+
'functional model, tf.keras.layers.Layer, '
62+
'list of tf.keras.layers.Layer. You passed an'
63+
' object of type: {input}.')
64+
5865
def setUp(self):
5966
super(PruneTest, self).setUp()
6067

@@ -319,9 +326,21 @@ def testPruneFunctionalModelPreservesBuiltState(self):
319326
self.assertEqual(loaded_model.built, True)
320327

321328
def testPruneSubclassModel(self):
322-
model = TestModel()
323-
with self.assertRaises(ValueError):
329+
model = TestSubclassedModel()
330+
with self.assertRaises(ValueError) as e:
324331
_ = prune.prune_low_magnitude(model, **self.params)
332+
self.assertEqual(
333+
str(e.exception),
334+
self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='TestSubclassedModel'))
335+
336+
def testPruneMiscObject(self):
337+
338+
model = object()
339+
with self.assertRaises(ValueError) as e:
340+
_ = prune.prune_low_magnitude(model, **self.params)
341+
self.assertEqual(
342+
str(e.exception),
343+
self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='object'))
325344

326345
def testStripPruningSequentialModel(self):
327346
model = keras.Sequential([

0 commit comments

Comments
 (0)