@@ -365,6 +365,13 @@ def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
365
365
'input_shape' : [(8 )],
366
366
'm_by_n' : (1 , 2 ),
367
367
},
368
+ {
369
+ 'testcase_name' : 'DepthwiseConv_2by4' ,
370
+ 'layer_type' : tf .keras .layers .DepthwiseConv2D ,
371
+ 'layer_arg' : [3 ],
372
+ 'input_shape' : (7 , 7 , 32 ),
373
+ 'm_by_n' : (2 , 4 ),
374
+ },
368
375
)
369
376
370
377
def testMbyNSparsityPruning_SupportedLayers (self ,
@@ -392,18 +399,45 @@ def testMbyNSparsityPruning_SupportedLayers(self,
392
399
test_utils .assert_model_sparsity_m_by_n (self , model , m_by_n )
393
400
self ._check_strip_pruning_matches_original (model , sparsity_ratio )
394
401
395
- def testSparsityPruningMbyN_NonSupportedLayers (self ):
396
- """Check layer that is not supported for m by n sparsity."""
397
- self .params .update ({'sparsity_m_by_n' : (2 , 4 )})
398
-
399
- model = keras .Sequential ()
400
- layer_type = tf .keras .layers .SeparableConv1D
401
- args , input_shape = ([4 , 3 ], (3 , 6 ))
402
+ def testSparsityPruningMbyN_SupportedSubclassLayers (self ):
403
+ """Check subclass layer that is supported for m by n sparsity."""
404
+ m_by_n = (2 , 4 )
405
+ self .params .update ({'sparsity_m_by_n' : m_by_n })
402
406
407
+ class SubclassLayer (tf .keras .layers .Layer ):
408
+
409
+ def __init__ (self ):
410
+ super (SubclassLayer , self ).__init__ ()
411
+ self .conv1 = tf .keras .layers .Conv2D (
412
+ 2 , 3 , activation = 'relu' , padding = 'same' , input_shape = [7 , 7 , 3 ])
413
+ self .conv2 = tf .keras .layers .DepthwiseConv2D (3 )
414
+ self .flatten = keras .layers .Flatten ()
415
+ self .dense = layers .Dense (10 , activation = 'sigmoid' )
416
+
417
+ def call (self , inputs ):
418
+ x = self .conv1 (inputs )
419
+ x = self .conv2 (x )
420
+ x = self .flatten (x )
421
+ x = self .dense (x )
422
+ return x
423
+
424
+ inputs = keras .Input (shape = (7 , 7 , 3 ))
425
+ outputs = SubclassLayer ()(inputs )
426
+ model = keras .Model (inputs , outputs )
403
427
with self .assertRaises (ValueError ):
404
- model .add (
405
- prune .prune_low_magnitude (
406
- layer_type (* args ), input_shape = input_shape , ** self .params ))
428
+ model = prune .prune_low_magnitude (model , ** self .params )
429
+
430
+ model .compile (
431
+ loss = 'categorical_crossentropy' , optimizer = 'sgd' , metrics = ['accuracy' ])
432
+
433
+ test_utils .assert_model_sparsity (self , 0.0 , model )
434
+ model .fit (
435
+ np .random .randn (* self ._batch (model .input .get_shape ().as_list (), 32 )),
436
+ np .random .randn (* self ._batch (model .output .get_shape ().as_list (), 32 )),
437
+ callbacks = [pruning_callbacks .UpdatePruningStep ()])
438
+
439
+ test_utils .assert_model_sparsity_m_by_n (self , model , m_by_n )
440
+ self ._check_strip_pruning_matches_original (model , 0.5 )
407
441
408
442
@parameterized .parameters (prune_registry .PruneRegistry ._RNN_LAYERS -
409
443
{keras .layers .RNN })
0 commit comments