Skip to content

Commit e64ea0b

Browse files
rino20tensorflower-gardener
authored andcommitted
Make structured sparsity for available for functional layers
PiperOrigin-RevId: 396734906
1 parent 70fb709 commit e64ea0b

File tree

3 files changed

+53
-20
lines changed

3 files changed

+53
-20
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,13 @@ def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
365365
'input_shape': [(8)],
366366
'm_by_n': (1, 2),
367367
},
368+
{
369+
'testcase_name': 'DepthwiseConv_2by4',
370+
'layer_type': tf.keras.layers.DepthwiseConv2D,
371+
'layer_arg': [3],
372+
'input_shape': (7, 7, 32),
373+
'm_by_n': (2, 4),
374+
},
368375
)
369376

370377
def testMbyNSparsityPruning_SupportedLayers(self,
@@ -392,18 +399,45 @@ def testMbyNSparsityPruning_SupportedLayers(self,
392399
test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
393400
self._check_strip_pruning_matches_original(model, sparsity_ratio)
394401

395-
def testSparsityPruningMbyN_NonSupportedLayers(self):
396-
"""Check layer that is not supported for m by n sparsity."""
397-
self.params.update({'sparsity_m_by_n': (2, 4)})
398-
399-
model = keras.Sequential()
400-
layer_type = tf.keras.layers.SeparableConv1D
401-
args, input_shape = ([4, 3], (3, 6))
402+
def testSparsityPruningMbyN_SupportedSubclassLayers(self):
403+
"""Check subclass layer that is supported for m by n sparsity."""
404+
m_by_n = (2, 4)
405+
self.params.update({'sparsity_m_by_n': m_by_n})
402406

407+
class SubclassLayer(tf.keras.layers.Layer):
408+
409+
def __init__(self):
410+
super(SubclassLayer, self).__init__()
411+
self.conv1 = tf.keras.layers.Conv2D(
412+
2, 3, activation='relu', padding='same', input_shape=[7, 7, 3])
413+
self.conv2 = tf.keras.layers.DepthwiseConv2D(3)
414+
self.flatten = keras.layers.Flatten()
415+
self.dense = layers.Dense(10, activation='sigmoid')
416+
417+
def call(self, inputs):
418+
x = self.conv1(inputs)
419+
x = self.conv2(x)
420+
x = self.flatten(x)
421+
x = self.dense(x)
422+
return x
423+
424+
inputs = keras.Input(shape=(7, 7, 3))
425+
outputs = SubclassLayer()(inputs)
426+
model = keras.Model(inputs, outputs)
403427
with self.assertRaises(ValueError):
404-
model.add(
405-
prune.prune_low_magnitude(
406-
layer_type(*args), input_shape=input_shape, **self.params))
428+
model = prune.prune_low_magnitude(model, **self.params)
429+
430+
model.compile(
431+
loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
432+
433+
test_utils.assert_model_sparsity(self, 0.0, model)
434+
model.fit(
435+
np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
436+
np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
437+
callbacks=[pruning_callbacks.UpdatePruningStep()])
438+
439+
test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
440+
self._check_strip_pruning_matches_original(model, 0.5)
407441

408442
@parameterized.parameters(prune_registry.PruneRegistry._RNN_LAYERS -
409443
{keras.layers.RNN})

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,6 @@ def __init__(self,
129129
self.sparsity_m_by_n = None
130130

131131
if sparsity_m_by_n:
132-
# Sparsity m_by_n can be applied only to Conv2D and Dense layers.
133-
if (not isinstance(layer, tf.keras.layers.Conv2D) and
134-
not isinstance(layer, tf.keras.layers.Dense)):
135-
raise ValueError('Structural sparsity M by N is applicable only '
136-
'to `Conv2D` and `Dense` layers. You passed: '
137-
'{input}'.format(input=layer.__class__))
138-
139132
self.sparsity_m_by_n = convert_to_tuple_of_two_int(
140133
sparsity_m_by_n, 'sparsity_m_by_n')
141134

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,18 @@ def testCollectPrunableLayers(self):
156156

157157
self.assertLen(pruning_wrapper.collect_prunable_layers(self.model), 5)
158158

159-
def testConv3DNonPrunableWithSparsityMbyN(self):
159+
def testConv3DWeightNotPrunedWithSparsityMbyN(self):
160160
layer = keras.layers.Conv3D(2, 3)
161161
inputs = keras.layers.Input(shape=(4, 28, 28, 28, 1))
162162
_ = layer(inputs)
163-
with self.assertRaises(ValueError):
164-
pruning_wrapper.PruneLowMagnitude(layer, sparsity_m_by_n=(2, 4))
163+
self.model.add(Prune(layer, sparsity_m_by_n=(2, 4)))
164+
165+
pruned_layers = pruning_wrapper.collect_prunable_layers(self.model)
166+
167+
self.assertLen(pruned_layers, 1)
168+
# Only rank-2 (e.g, Conv2D) or rank-4 (e.g, Dense) weight are pruned with
169+
# M-by-N sparsity.
170+
self.assertLen(pruned_layers[0].pruning_vars, 0)
165171

166172

167173
if __name__ == '__main__':

0 commit comments

Comments
 (0)