@@ -52,13 +52,13 @@ def prune_low_magnitude(to_prune,
52
52
** kwargs ):
53
53
"""Modify a keras layer or model to be pruned during training.
54
54
55
- This function wraps a keras model or layer with pruning functionality which
55
+ This function wraps a tf. keras model or layer with pruning functionality which
56
56
sparsifies the layer's weights during training. For example, using this with
57
57
50% sparsity will ensure that 50% of the layer's weights are zero.
58
58
59
59
The function accepts either a single keras layer
60
- (subclass of `keras.layers.Layer`), list of keras layers or a keras model
61
- (instance of ` keras.models.Model`) and handles them appropriately.
60
+ (subclass of `tf. keras.layers.Layer`), list of keras layers or a Sequential
61
+ or Functional keras model and handles them appropriately.
62
62
63
63
If it encounters a layer it does not know how to handle, it will throw an
64
64
error. While pruning an entire model, even a single unknown layer would lead
@@ -144,15 +144,28 @@ def _add_pruning_wrapper(layer):
144
144
'block_size' : block_size ,
145
145
'block_pooling_type' : block_pooling_type
146
146
}
147
+ is_sequential_or_functional = isinstance (
148
+ to_prune , keras .Model ) and (isinstance (to_prune , keras .Sequential ) or
149
+ to_prune ._is_graph_network )
150
+
151
+ # A subclassed model is also a subclass of keras.layers.Layer.
152
+ is_keras_layer = isinstance (
153
+ to_prune , keras .layers .Layer ) and not isinstance (to_prune , keras .Model )
147
154
148
155
if isinstance (to_prune , list ):
149
156
return _prune_list (to_prune , ** params )
150
- elif isinstance ( to_prune , keras . Model ) :
157
+ elif is_sequential_or_functional :
151
158
return keras .models .clone_model (
152
159
to_prune , input_tensors = None , clone_function = _add_pruning_wrapper )
153
- elif isinstance ( to_prune , keras . layers . Layer ) :
160
+ elif is_keras_layer :
154
161
params .update (kwargs )
155
162
return pruning_wrapper .PruneLowMagnitude (to_prune , ** params )
163
+ else :
164
+ raise ValueError (
165
+ '`prune_low_magnitude` can only prune an object of the following '
166
+ 'types: tf.keras.models.Sequential, tf.keras functional model, '
167
+ 'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
168
+ 'an object of type: {input}.' .format (input = to_prune .__class__ .__name__ ))
156
169
157
170
158
171
def strip_pruning (model ):
0 commit comments