14
14
# ==============================================================================
15
15
"""Tests for prune registry."""
16
16
17
+ from absl .testing import parameterized
17
18
import tensorflow as tf
18
19
19
20
from tensorflow_model_optimization .python .core .sparsity .keras import prunable_layer
24
25
PruneRegistry = prune_registry .PruneRegistry
25
26
26
27
27
- class PruneRegistryTest (tf .test .TestCase ):
28
-
29
- class CustomLayer (layers .Layer ):
30
- pass
31
-
32
- class CustomLayerFromPrunableLayer (layers .Dense ):
33
- pass
34
-
35
- class MinimalRNNCell (keras .layers .Layer ):
36
-
37
- def __init__ (self , units , ** kwargs ):
38
- self .units = units
39
- self .state_size = units
40
- super (PruneRegistryTest .MinimalRNNCell , self ).__init__ (** kwargs )
41
-
42
- def build (self , input_shape ):
43
- self .kernel = self .add_weight (shape = (input_shape [- 1 ], self .units ),
44
- initializer = 'uniform' ,
45
- name = 'kernel' )
46
- self .recurrent_kernel = self .add_weight (
47
- shape = (self .units , self .units ),
48
- initializer = 'uniform' ,
49
- name = 'recurrent_kernel' )
50
- self .built = True
51
-
52
- def call (self , inputs , states ):
53
- prev_output = states [0 ]
54
- h = keras .backend .dot (inputs , self .kernel )
55
- output = h + keras .backend .dot (prev_output , self .recurrent_kernel )
56
- return output , [output ]
57
-
58
- class MinimalRNNCellPrunable (MinimalRNNCell , prunable_layer .PrunableLayer ):
59
-
60
- def get_prunable_weights (self ):
61
- return [self .kernel , self .recurrent_kernel ]
62
-
63
- def testSupportsKerasPrunableLayer (self ):
64
- self .assertTrue (PruneRegistry .supports (layers .Dense (10 )))
65
-
66
- def testSupportsKerasPrunableLayerAlias (self ):
67
- # layers.Conv2D maps to layers.convolutional.Conv2D
68
- self .assertTrue (PruneRegistry .supports (layers .Conv2D (10 , 5 )))
69
-
70
- def testSupportsKerasNonPrunableLayer (self ):
71
- # Dropout is a layer known to not be prunable.
72
- self .assertTrue (PruneRegistry .supports (layers .Dropout (0.5 )))
73
-
74
- def testDoesNotSupportKerasUnsupportedLayer (self ):
75
- # ConvLSTM2D is a built-in keras layer but not supported.
76
- self .assertFalse (PruneRegistry .supports (layers .ConvLSTM2D (2 , (5 , 5 ))))
77
-
78
- def testSupportsKerasRNNLayers (self ):
79
- self .assertTrue (PruneRegistry .supports (layers .LSTM (10 )))
80
- self .assertTrue (PruneRegistry .supports (layers .GRU (10 )))
81
- self .assertTrue (PruneRegistry .supports (layers .SimpleRNN (10 )))
82
-
83
- def testSupportsKerasRNNLayerWithRNNCellsParams (self ):
84
- self .assertTrue (PruneRegistry .supports (layers .RNN (layers .LSTMCell (10 ))))
85
-
86
- self .assertTrue (
87
- PruneRegistry .supports (
88
- layers .RNN ([
89
- layers .LSTMCell (10 ),
90
- layers .GRUCell (10 ),
91
- keras .experimental .PeepholeLSTMCell (10 ),
92
- layers .SimpleRNNCell (10 )
93
- ])))
94
-
95
- def testDoesNotSupportKerasRNNLayerUnknownCell (self ):
96
- self .assertFalse (PruneRegistry .supports (
97
- keras .layers .RNN (PruneRegistryTest .MinimalRNNCell (32 ))))
98
-
99
- def testSupportsKerasRNNLayerPrunableCell (self ):
100
- self .assertTrue (PruneRegistry .supports (
101
- keras .layers .RNN (PruneRegistryTest .MinimalRNNCellPrunable (32 ))))
102
-
103
- def testDoesNotSupportCustomLayer (self ):
104
- self .assertFalse (PruneRegistry .supports (PruneRegistryTest .CustomLayer ()))
105
-
106
- def testDoesNotSupportCustomLayerInheritedFromPrunableLayer (self ):
107
- self .assertFalse (
108
- PruneRegistry .supports (
109
- PruneRegistryTest .CustomLayerFromPrunableLayer (10 )))
28
+ class CustomLayer (layers .Layer ):
29
+ pass
30
+
31
+
32
+ class CustomLayerFromPrunableLayer (layers .Dense ):
33
+ pass
34
+
35
+
36
+ class MinimalRNNCell (keras .layers .Layer ):
37
+
38
+ def __init__ (self , units , ** kwargs ):
39
+ self .units = units
40
+ self .state_size = units
41
+ super (MinimalRNNCell , self ).__init__ (** kwargs )
42
+
43
+ def build (self , input_shape ):
44
+ self .kernel = self .add_weight (
45
+ shape = (input_shape [- 1 ], self .units ),
46
+ initializer = 'uniform' ,
47
+ name = 'kernel' )
48
+ self .recurrent_kernel = self .add_weight (
49
+ shape = (self .units , self .units ),
50
+ initializer = 'uniform' ,
51
+ name = 'recurrent_kernel' )
52
+ self .built = True
53
+
54
+ def call (self , inputs , states ):
55
+ prev_output = states [0 ]
56
+ h = keras .backend .dot (inputs , self .kernel )
57
+ output = h + keras .backend .dot (prev_output , self .recurrent_kernel )
58
+ return output , [output ]
59
+
60
+
61
+ class MinimalRNNCellPrunable (MinimalRNNCell , prunable_layer .PrunableLayer ):
62
+
63
+ def get_prunable_weights (self ):
64
+ return [self .kernel , self .recurrent_kernel ]
65
+
66
+
67
+ class PruneRegistryTest (tf .test .TestCase , parameterized .TestCase ):
68
+
69
+ _PRUNE_REGISTRY_SUPPORTED_LAYERS = [
70
+ # Supports basic Keras layers even though it is not prunbale.
71
+ layers .Dense (10 ),
72
+ layers .Conv2D (10 , 5 ),
73
+ layers .Dropout (0.5 ),
74
+ # Supports specific layers from experimental or compat_v1.
75
+ tf .keras .layers .experimental .preprocessing .Rescaling ,
76
+ tf .compat .v1 .keras .layers .BatchNormalization (),
77
+ # Supports Keras RNN Layers with prunable cells.
78
+ layers .LSTM (10 ),
79
+ layers .GRU (10 ),
80
+ layers .SimpleRNN (10 ),
81
+ layers .RNN (layers .LSTMCell (10 )),
82
+ layers .RNN ([
83
+ layers .LSTMCell (10 ),
84
+ layers .GRUCell (10 ),
85
+ keras .experimental .PeepholeLSTMCell (10 ),
86
+ layers .SimpleRNNCell (10 )
87
+ ]),
88
+ keras .layers .RNN (MinimalRNNCellPrunable (32 )),
89
+ ]
90
+
91
+ @parameterized .parameters (_PRUNE_REGISTRY_SUPPORTED_LAYERS )
92
+ def testSupportsLayer (self , layer ):
93
+ self .assertTrue (PruneRegistry .supports (layer ))
94
+
95
+ _PRUNE_REGISTRY_UNSUPPORTED_LAYERS = [
96
+ # Not support a few built-in keras layers.
97
+ layers .ConvLSTM2D (2 , (5 , 5 )),
98
+ # Not support RNN layers with unknown cell
99
+ keras .layers .RNN (MinimalRNNCell (32 )),
100
+ # Not support Custom layers, even though inherited from prunable layer.
101
+ CustomLayer (),
102
+ CustomLayerFromPrunableLayer (10 ),
103
+ ]
104
+
105
+ @parameterized .parameters (_PRUNE_REGISTRY_UNSUPPORTED_LAYERS )
106
+ def testDoesNotSupportLayer (self , layer ):
107
+ self .assertFalse (PruneRegistry .supports (layer ))
110
108
111
109
def testMakePrunableRaisesErrorForKerasUnsupportedLayer (self ):
112
110
with self .assertRaises (ValueError ):
113
111
PruneRegistry .make_prunable (layers .ConvLSTM2D (2 , (5 , 5 )))
114
112
115
113
def testMakePrunableRaisesErrorForCustomLayer (self ):
116
114
with self .assertRaises (ValueError ):
117
- PruneRegistry .make_prunable (PruneRegistryTest . CustomLayer ())
115
+ PruneRegistry .make_prunable (CustomLayer ())
118
116
119
117
def testMakePrunableRaisesErrorForCustomLayerInheritedFromPrunableLayer (self ):
120
118
with self .assertRaises (ValueError ):
121
- PruneRegistry .make_prunable (
122
- PruneRegistryTest .CustomLayerFromPrunableLayer (10 ))
119
+ PruneRegistry .make_prunable (CustomLayerFromPrunableLayer (10 ))
123
120
124
121
def testMakePrunableWorksOnKerasPrunableLayer (self ):
125
122
layer = layers .Dense (10 )
@@ -171,7 +168,7 @@ def testMakePrunableWorksOnKerasRNNLayerWithRNNCellsParams(self):
171
168
172
169
def testMakePrunableWorksOnKerasRNNLayerWithPrunableCell (self ):
173
170
cell1 = layers .LSTMCell (10 )
174
- cell2 = PruneRegistryTest . MinimalRNNCellPrunable (5 )
171
+ cell2 = MinimalRNNCellPrunable (5 )
175
172
layer = layers .RNN ([cell1 , cell2 ])
176
173
with self .assertRaises (AttributeError ):
177
174
layer .get_prunable_weights ()
@@ -187,12 +184,9 @@ def testMakePrunableWorksOnKerasRNNLayerWithPrunableCell(self):
187
184
188
185
def testMakePrunableRaisesErrorOnRNNLayersUnsupportedCell (self ):
189
186
with self .assertRaises (ValueError ):
190
- PruneRegistry .make_prunable (layers .RNN (
191
- [layers .LSTMCell (10 ), PruneRegistryTest .MinimalRNNCell (5 )]))
192
-
193
- def testRescalingLayer (self ):
194
- self .assertTrue (PruneRegistry .supports (
195
- tf .keras .layers .experimental .preprocessing .Rescaling ))
187
+ PruneRegistry .make_prunable (
188
+ layers .RNN ([layers .LSTMCell (10 ),
189
+ MinimalRNNCell (5 )]))
196
190
197
191
198
192
if __name__ == '__main__' :
0 commit comments