@@ -365,6 +365,13 @@ def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
           'input_shape': [(8)],
           'm_by_n': (1, 2),
       },
+      {
+          'testcase_name': 'DepthwiseConv_2by4',
+          'layer_type': tf.keras.layers.DepthwiseConv2D,
+          'layer_arg': [3],
+          'input_shape': (7, 7, 32),
+          'm_by_n': (2, 4),
+      },
   )

   def testMbyNSparsityPruning_SupportedLayers(self,
@@ -392,18 +399,45 @@ def testMbyNSparsityPruning_SupportedLayers(self,
     test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
     self._check_strip_pruning_matches_original(model, sparsity_ratio)

-  def testSparsityPruningMbyN_NonSupportedLayers(self):
-    """Check layer that is not supported for m by n sparsity."""
-    self.params.update({'sparsity_m_by_n': (2, 4)})
-
-    model = keras.Sequential()
-    layer_type = tf.keras.layers.SeparableConv1D
-    args, input_shape = ([4, 3], (3, 6))
+  def testSparsityPruningMbyN_SupportedSubclassLayers(self):
+    """Check subclass layer that is supported for m by n sparsity."""
+    m_by_n = (2, 4)
+    self.params.update({'sparsity_m_by_n': m_by_n})

+    class SubclassLayer(tf.keras.layers.Layer):
+
+      def __init__(self):
+        super(SubclassLayer, self).__init__()
+        self.conv1 = tf.keras.layers.Conv2D(
+            2, 3, activation='relu', padding='same', input_shape=[7, 7, 3])
+        self.conv2 = tf.keras.layers.DepthwiseConv2D(3)
+        self.flatten = keras.layers.Flatten()
+        self.dense = layers.Dense(10, activation='sigmoid')
+
+      def call(self, inputs):
+        x = self.conv1(inputs)
+        x = self.conv2(x)
+        x = self.flatten(x)
+        x = self.dense(x)
+        return x
+
+    inputs = keras.Input(shape=(7, 7, 3))
+    outputs = SubclassLayer()(inputs)
+    model = keras.Model(inputs, outputs)
     with self.assertRaises(ValueError):
-      model.add(
-          prune.prune_low_magnitude(
-              layer_type(*args), input_shape=input_shape, **self.params))
+      model = prune.prune_low_magnitude(model, **self.params)
+
+    model.compile(
+        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
+
+    test_utils.assert_model_sparsity(self, 0.0, model)
+    model.fit(
+        np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
+        np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
+        callbacks=[pruning_callbacks.UpdatePruningStep()])
+
+    test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
+    self._check_strip_pruning_matches_original(model, 0.5)

   @parameterized.parameters(prune_registry.PruneRegistry._RNN_LAYERS -
                             {keras.layers.RNN})
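
For context, here is a minimal usage sketch of the workflow these tests exercise: wrapping a Keras model with prune_low_magnitude and a sparsity_m_by_n=(2, 4) constraint, training with the UpdatePruningStep callback, then stripping the pruning wrappers. This sketch is not part of the diff; the model architecture and data shapes are illustrative assumptions, and it assumes a tensorflow_model_optimization build in which prune_low_magnitude accepts the sparsity_m_by_n argument and DepthwiseConv2D supports 2:4 pruning, which is exactly what the new DepthwiseConv_2by4 test case checks.

# Minimal sketch, assuming tfmot with sparsity_m_by_n support; shapes mirror
# the new DepthwiseConv_2by4 test case but are otherwise illustrative.
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.DepthwiseConv2D(3, input_shape=(7, 7, 32)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='sigmoid'),
])

# Keep at most 2 non-zero weights in every block of 4 consecutive weights (2:4).
pruned = tfmot.sparsity.keras.prune_low_magnitude(model, sparsity_m_by_n=(2, 4))

pruned.compile(
    loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
pruned.fit(
    np.random.randn(32, 7, 7, 32),
    np.random.randn(32, 10),
    epochs=1,
    # The callback advances the pruning step so the 2:4 masks are applied.
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Remove the pruning wrappers before export; the 2:4 weight pattern remains.
final_model = tfmot.sparsity.keras.strip_pruning(pruned)

The test file reaches the same functionality through its internal imports (prune, pruning_callbacks, test_utils); the public tfmot.sparsity.keras aliases above are the equivalent end-user entry points.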