1414# ==============================================================================
1515"""Tests for prune registry."""
1616
17+ from absl .testing import parameterized
1718import tensorflow as tf
1819
1920from tensorflow_model_optimization .python .core .sparsity .keras import prunable_layer
2425PruneRegistry = prune_registry .PruneRegistry
2526
2627
27- class PruneRegistryTest (tf .test .TestCase ):
28-
29- class CustomLayer (layers .Layer ):
30- pass
31-
32- class CustomLayerFromPrunableLayer (layers .Dense ):
33- pass
34-
35- class MinimalRNNCell (keras .layers .Layer ):
36-
37- def __init__ (self , units , ** kwargs ):
38- self .units = units
39- self .state_size = units
40- super (PruneRegistryTest .MinimalRNNCell , self ).__init__ (** kwargs )
41-
42- def build (self , input_shape ):
43- self .kernel = self .add_weight (shape = (input_shape [- 1 ], self .units ),
44- initializer = 'uniform' ,
45- name = 'kernel' )
46- self .recurrent_kernel = self .add_weight (
47- shape = (self .units , self .units ),
48- initializer = 'uniform' ,
49- name = 'recurrent_kernel' )
50- self .built = True
51-
52- def call (self , inputs , states ):
53- prev_output = states [0 ]
54- h = keras .backend .dot (inputs , self .kernel )
55- output = h + keras .backend .dot (prev_output , self .recurrent_kernel )
56- return output , [output ]
57-
58- class MinimalRNNCellPrunable (MinimalRNNCell , prunable_layer .PrunableLayer ):
59-
60- def get_prunable_weights (self ):
61- return [self .kernel , self .recurrent_kernel ]
62-
63- def testSupportsKerasPrunableLayer (self ):
64- self .assertTrue (PruneRegistry .supports (layers .Dense (10 )))
65-
66- def testSupportsKerasPrunableLayerAlias (self ):
67- # layers.Conv2D maps to layers.convolutional.Conv2D
68- self .assertTrue (PruneRegistry .supports (layers .Conv2D (10 , 5 )))
69-
70- def testSupportsKerasNonPrunableLayer (self ):
71- # Dropout is a layer known to not be prunable.
72- self .assertTrue (PruneRegistry .supports (layers .Dropout (0.5 )))
73-
74- def testDoesNotSupportKerasUnsupportedLayer (self ):
75- # ConvLSTM2D is a built-in keras layer but not supported.
76- self .assertFalse (PruneRegistry .supports (layers .ConvLSTM2D (2 , (5 , 5 ))))
77-
78- def testSupportsKerasRNNLayers (self ):
79- self .assertTrue (PruneRegistry .supports (layers .LSTM (10 )))
80- self .assertTrue (PruneRegistry .supports (layers .GRU (10 )))
81- self .assertTrue (PruneRegistry .supports (layers .SimpleRNN (10 )))
82-
83- def testSupportsKerasRNNLayerWithRNNCellsParams (self ):
84- self .assertTrue (PruneRegistry .supports (layers .RNN (layers .LSTMCell (10 ))))
85-
86- self .assertTrue (
87- PruneRegistry .supports (
88- layers .RNN ([
89- layers .LSTMCell (10 ),
90- layers .GRUCell (10 ),
91- keras .experimental .PeepholeLSTMCell (10 ),
92- layers .SimpleRNNCell (10 )
93- ])))
94-
95- def testDoesNotSupportKerasRNNLayerUnknownCell (self ):
96- self .assertFalse (PruneRegistry .supports (
97- keras .layers .RNN (PruneRegistryTest .MinimalRNNCell (32 ))))
98-
99- def testSupportsKerasRNNLayerPrunableCell (self ):
100- self .assertTrue (PruneRegistry .supports (
101- keras .layers .RNN (PruneRegistryTest .MinimalRNNCellPrunable (32 ))))
102-
103- def testDoesNotSupportCustomLayer (self ):
104- self .assertFalse (PruneRegistry .supports (PruneRegistryTest .CustomLayer ()))
105-
106- def testDoesNotSupportCustomLayerInheritedFromPrunableLayer (self ):
107- self .assertFalse (
108- PruneRegistry .supports (
109- PruneRegistryTest .CustomLayerFromPrunableLayer (10 )))
28+ class CustomLayer (layers .Layer ):
29+ pass
30+
31+
32+ class CustomLayerFromPrunableLayer (layers .Dense ):
33+ pass
34+
35+
36+ class MinimalRNNCell (keras .layers .Layer ):
37+
38+ def __init__ (self , units , ** kwargs ):
39+ self .units = units
40+ self .state_size = units
41+ super (MinimalRNNCell , self ).__init__ (** kwargs )
42+
43+ def build (self , input_shape ):
44+ self .kernel = self .add_weight (
45+ shape = (input_shape [- 1 ], self .units ),
46+ initializer = 'uniform' ,
47+ name = 'kernel' )
48+ self .recurrent_kernel = self .add_weight (
49+ shape = (self .units , self .units ),
50+ initializer = 'uniform' ,
51+ name = 'recurrent_kernel' )
52+ self .built = True
53+
54+ def call (self , inputs , states ):
55+ prev_output = states [0 ]
56+ h = keras .backend .dot (inputs , self .kernel )
57+ output = h + keras .backend .dot (prev_output , self .recurrent_kernel )
58+ return output , [output ]
59+
60+
61+ class MinimalRNNCellPrunable (MinimalRNNCell , prunable_layer .PrunableLayer ):
62+
63+ def get_prunable_weights (self ):
64+ return [self .kernel , self .recurrent_kernel ]
65+
66+
67+ class PruneRegistryTest (tf .test .TestCase , parameterized .TestCase ):
68+
69+ _PRUNE_REGISTRY_SUPPORTED_LAYERS = [
70+ # Supports basic Keras layers even though it is not prunbale.
71+ layers .Dense (10 ),
72+ layers .Conv2D (10 , 5 ),
73+ layers .Dropout (0.5 ),
74+ # Supports specific layers from experimental or compat_v1.
75+ tf .keras .layers .experimental .preprocessing .Rescaling ,
76+ tf .compat .v1 .keras .layers .BatchNormalization (),
77+ # Supports Keras RNN Layers with prunable cells.
78+ layers .LSTM (10 ),
79+ layers .GRU (10 ),
80+ layers .SimpleRNN (10 ),
81+ layers .RNN (layers .LSTMCell (10 )),
82+ layers .RNN ([
83+ layers .LSTMCell (10 ),
84+ layers .GRUCell (10 ),
85+ keras .experimental .PeepholeLSTMCell (10 ),
86+ layers .SimpleRNNCell (10 )
87+ ]),
88+ keras .layers .RNN (MinimalRNNCellPrunable (32 )),
89+ ]
90+
91+ @parameterized .parameters (_PRUNE_REGISTRY_SUPPORTED_LAYERS )
92+ def testSupportsLayer (self , layer ):
93+ self .assertTrue (PruneRegistry .supports (layer ))
94+
95+ _PRUNE_REGISTRY_UNSUPPORTED_LAYERS = [
96+ # Not support a few built-in keras layers.
97+ layers .ConvLSTM2D (2 , (5 , 5 )),
98+ # Not support RNN layers with unknown cell
99+ keras .layers .RNN (MinimalRNNCell (32 )),
100+ # Not support Custom layers, even though inherited from prunable layer.
101+ CustomLayer (),
102+ CustomLayerFromPrunableLayer (10 ),
103+ ]
104+
105+ @parameterized .parameters (_PRUNE_REGISTRY_UNSUPPORTED_LAYERS )
106+ def testDoesNotSupportLayer (self , layer ):
107+ self .assertFalse (PruneRegistry .supports (layer ))
110108
111109 def testMakePrunableRaisesErrorForKerasUnsupportedLayer (self ):
112110 with self .assertRaises (ValueError ):
113111 PruneRegistry .make_prunable (layers .ConvLSTM2D (2 , (5 , 5 )))
114112
115113 def testMakePrunableRaisesErrorForCustomLayer (self ):
116114 with self .assertRaises (ValueError ):
117- PruneRegistry .make_prunable (PruneRegistryTest . CustomLayer ())
115+ PruneRegistry .make_prunable (CustomLayer ())
118116
119117 def testMakePrunableRaisesErrorForCustomLayerInheritedFromPrunableLayer (self ):
120118 with self .assertRaises (ValueError ):
121- PruneRegistry .make_prunable (
122- PruneRegistryTest .CustomLayerFromPrunableLayer (10 ))
119+ PruneRegistry .make_prunable (CustomLayerFromPrunableLayer (10 ))
123120
124121 def testMakePrunableWorksOnKerasPrunableLayer (self ):
125122 layer = layers .Dense (10 )
@@ -171,7 +168,7 @@ def testMakePrunableWorksOnKerasRNNLayerWithRNNCellsParams(self):
171168
172169 def testMakePrunableWorksOnKerasRNNLayerWithPrunableCell (self ):
173170 cell1 = layers .LSTMCell (10 )
174- cell2 = PruneRegistryTest . MinimalRNNCellPrunable (5 )
171+ cell2 = MinimalRNNCellPrunable (5 )
175172 layer = layers .RNN ([cell1 , cell2 ])
176173 with self .assertRaises (AttributeError ):
177174 layer .get_prunable_weights ()
@@ -187,12 +184,9 @@ def testMakePrunableWorksOnKerasRNNLayerWithPrunableCell(self):
187184
188185 def testMakePrunableRaisesErrorOnRNNLayersUnsupportedCell (self ):
189186 with self .assertRaises (ValueError ):
190- PruneRegistry .make_prunable (layers .RNN (
191- [layers .LSTMCell (10 ), PruneRegistryTest .MinimalRNNCell (5 )]))
192-
193- def testRescalingLayer (self ):
194- self .assertTrue (PruneRegistry .supports (
195- tf .keras .layers .experimental .preprocessing .Rescaling ))
187+ PruneRegistry .make_prunable (
188+ layers .RNN ([layers .LSTMCell (10 ),
189+ MinimalRNNCell (5 )]))
196190
197191
198192if __name__ == '__main__' :
0 commit comments