Skip to content

Commit e12b3b3

Browse files
rino20tensorflower-gardener
authored andcommitted
Add compat_v1 batchnorm to PruneRegistry
PiperOrigin-RevId: 419516112
1 parent 0e56d30 commit e12b3b3

File tree

3 files changed

+93
-97
lines changed

3 files changed

+93
-97
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ py_strict_test(
263263
deps = [
264264
":prunable_layer",
265265
":prune_registry",
266+
# absl/testing:parameterized dep1,
266267
# tensorflow dep1,
267268
],
268269
)

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2222

2323
layers = tf.keras.layers
24+
layers_compat_v1 = tf.compat.v1.keras.layers
2425

2526

2627
class PruneRegistry(object):
@@ -94,12 +95,12 @@ class PruneRegistry(object):
9495
layers.MaxPooling2D: [],
9596
layers.MaxPooling3D: [],
9697
layers.MultiHeadAttention: [
97-
'_query_dense.kernel',
98-
'_key_dense.kernel',
99-
'_value_dense.kernel',
100-
'_output_dense.kernel'],
98+
'_query_dense.kernel', '_key_dense.kernel', '_value_dense.kernel',
99+
'_output_dense.kernel'
100+
],
101101
layers.experimental.preprocessing.Rescaling.__class__: [],
102102
TensorFlowOpLayer: [],
103+
layers_compat_v1.BatchNormalization: [],
103104
}
104105

105106
_RNN_CELLS_WEIGHTS_MAP = {

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry_test.py

Lines changed: 87 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
"""Tests for prune registry."""
1616

17+
from absl.testing import parameterized
1718
import tensorflow as tf
1819

1920
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
@@ -24,102 +25,98 @@
2425
PruneRegistry = prune_registry.PruneRegistry
2526

2627

27-
class PruneRegistryTest(tf.test.TestCase):
28-
29-
class CustomLayer(layers.Layer):
30-
pass
31-
32-
class CustomLayerFromPrunableLayer(layers.Dense):
33-
pass
34-
35-
class MinimalRNNCell(keras.layers.Layer):
36-
37-
def __init__(self, units, **kwargs):
38-
self.units = units
39-
self.state_size = units
40-
super(PruneRegistryTest.MinimalRNNCell, self).__init__(**kwargs)
41-
42-
def build(self, input_shape):
43-
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
44-
initializer='uniform',
45-
name='kernel')
46-
self.recurrent_kernel = self.add_weight(
47-
shape=(self.units, self.units),
48-
initializer='uniform',
49-
name='recurrent_kernel')
50-
self.built = True
51-
52-
def call(self, inputs, states):
53-
prev_output = states[0]
54-
h = keras.backend.dot(inputs, self.kernel)
55-
output = h + keras.backend.dot(prev_output, self.recurrent_kernel)
56-
return output, [output]
57-
58-
class MinimalRNNCellPrunable(MinimalRNNCell, prunable_layer.PrunableLayer):
59-
60-
def get_prunable_weights(self):
61-
return [self.kernel, self.recurrent_kernel]
62-
63-
def testSupportsKerasPrunableLayer(self):
64-
self.assertTrue(PruneRegistry.supports(layers.Dense(10)))
65-
66-
def testSupportsKerasPrunableLayerAlias(self):
67-
# layers.Conv2D maps to layers.convolutional.Conv2D
68-
self.assertTrue(PruneRegistry.supports(layers.Conv2D(10, 5)))
69-
70-
def testSupportsKerasNonPrunableLayer(self):
71-
# Dropout is a layer known to not be prunable.
72-
self.assertTrue(PruneRegistry.supports(layers.Dropout(0.5)))
73-
74-
def testDoesNotSupportKerasUnsupportedLayer(self):
75-
# ConvLSTM2D is a built-in keras layer but not supported.
76-
self.assertFalse(PruneRegistry.supports(layers.ConvLSTM2D(2, (5, 5))))
77-
78-
def testSupportsKerasRNNLayers(self):
79-
self.assertTrue(PruneRegistry.supports(layers.LSTM(10)))
80-
self.assertTrue(PruneRegistry.supports(layers.GRU(10)))
81-
self.assertTrue(PruneRegistry.supports(layers.SimpleRNN(10)))
82-
83-
def testSupportsKerasRNNLayerWithRNNCellsParams(self):
84-
self.assertTrue(PruneRegistry.supports(layers.RNN(layers.LSTMCell(10))))
85-
86-
self.assertTrue(
87-
PruneRegistry.supports(
88-
layers.RNN([
89-
layers.LSTMCell(10),
90-
layers.GRUCell(10),
91-
keras.experimental.PeepholeLSTMCell(10),
92-
layers.SimpleRNNCell(10)
93-
])))
94-
95-
def testDoesNotSupportKerasRNNLayerUnknownCell(self):
96-
self.assertFalse(PruneRegistry.supports(
97-
keras.layers.RNN(PruneRegistryTest.MinimalRNNCell(32))))
98-
99-
def testSupportsKerasRNNLayerPrunableCell(self):
100-
self.assertTrue(PruneRegistry.supports(
101-
keras.layers.RNN(PruneRegistryTest.MinimalRNNCellPrunable(32))))
102-
103-
def testDoesNotSupportCustomLayer(self):
104-
self.assertFalse(PruneRegistry.supports(PruneRegistryTest.CustomLayer()))
105-
106-
def testDoesNotSupportCustomLayerInheritedFromPrunableLayer(self):
107-
self.assertFalse(
108-
PruneRegistry.supports(
109-
PruneRegistryTest.CustomLayerFromPrunableLayer(10)))
28+
class CustomLayer(layers.Layer):
29+
pass
30+
31+
32+
class CustomLayerFromPrunableLayer(layers.Dense):
33+
pass
34+
35+
36+
class MinimalRNNCell(keras.layers.Layer):
37+
38+
def __init__(self, units, **kwargs):
39+
self.units = units
40+
self.state_size = units
41+
super(MinimalRNNCell, self).__init__(**kwargs)
42+
43+
def build(self, input_shape):
44+
self.kernel = self.add_weight(
45+
shape=(input_shape[-1], self.units),
46+
initializer='uniform',
47+
name='kernel')
48+
self.recurrent_kernel = self.add_weight(
49+
shape=(self.units, self.units),
50+
initializer='uniform',
51+
name='recurrent_kernel')
52+
self.built = True
53+
54+
def call(self, inputs, states):
55+
prev_output = states[0]
56+
h = keras.backend.dot(inputs, self.kernel)
57+
output = h + keras.backend.dot(prev_output, self.recurrent_kernel)
58+
return output, [output]
59+
60+
61+
class MinimalRNNCellPrunable(MinimalRNNCell, prunable_layer.PrunableLayer):
62+
63+
def get_prunable_weights(self):
64+
return [self.kernel, self.recurrent_kernel]
65+
66+
67+
class PruneRegistryTest(tf.test.TestCase, parameterized.TestCase):
68+
69+
_PRUNE_REGISTRY_SUPPORTED_LAYERS = [
70+
# Supports basic Keras layers even though it is not prunbale.
71+
layers.Dense(10),
72+
layers.Conv2D(10, 5),
73+
layers.Dropout(0.5),
74+
# Supports specific layers from experimental or compat_v1.
75+
tf.keras.layers.experimental.preprocessing.Rescaling,
76+
tf.compat.v1.keras.layers.BatchNormalization(),
77+
# Supports Keras RNN Layers with prunable cells.
78+
layers.LSTM(10),
79+
layers.GRU(10),
80+
layers.SimpleRNN(10),
81+
layers.RNN(layers.LSTMCell(10)),
82+
layers.RNN([
83+
layers.LSTMCell(10),
84+
layers.GRUCell(10),
85+
keras.experimental.PeepholeLSTMCell(10),
86+
layers.SimpleRNNCell(10)
87+
]),
88+
keras.layers.RNN(MinimalRNNCellPrunable(32)),
89+
]
90+
91+
@parameterized.parameters(_PRUNE_REGISTRY_SUPPORTED_LAYERS)
92+
def testSupportsLayer(self, layer):
93+
self.assertTrue(PruneRegistry.supports(layer))
94+
95+
_PRUNE_REGISTRY_UNSUPPORTED_LAYERS = [
96+
# Not support a few built-in keras layers.
97+
layers.ConvLSTM2D(2, (5, 5)),
98+
# Not support RNN layers with unknown cell
99+
keras.layers.RNN(MinimalRNNCell(32)),
100+
# Not support Custom layers, even though inherited from prunable layer.
101+
CustomLayer(),
102+
CustomLayerFromPrunableLayer(10),
103+
]
104+
105+
@parameterized.parameters(_PRUNE_REGISTRY_UNSUPPORTED_LAYERS)
106+
def testDoesNotSupportLayer(self, layer):
107+
self.assertFalse(PruneRegistry.supports(layer))
110108

111109
def testMakePrunableRaisesErrorForKerasUnsupportedLayer(self):
112110
with self.assertRaises(ValueError):
113111
PruneRegistry.make_prunable(layers.ConvLSTM2D(2, (5, 5)))
114112

115113
def testMakePrunableRaisesErrorForCustomLayer(self):
116114
with self.assertRaises(ValueError):
117-
PruneRegistry.make_prunable(PruneRegistryTest.CustomLayer())
115+
PruneRegistry.make_prunable(CustomLayer())
118116

119117
def testMakePrunableRaisesErrorForCustomLayerInheritedFromPrunableLayer(self):
120118
with self.assertRaises(ValueError):
121-
PruneRegistry.make_prunable(
122-
PruneRegistryTest.CustomLayerFromPrunableLayer(10))
119+
PruneRegistry.make_prunable(CustomLayerFromPrunableLayer(10))
123120

124121
def testMakePrunableWorksOnKerasPrunableLayer(self):
125122
layer = layers.Dense(10)
@@ -171,7 +168,7 @@ def testMakePrunableWorksOnKerasRNNLayerWithRNNCellsParams(self):
171168

172169
def testMakePrunableWorksOnKerasRNNLayerWithPrunableCell(self):
173170
cell1 = layers.LSTMCell(10)
174-
cell2 = PruneRegistryTest.MinimalRNNCellPrunable(5)
171+
cell2 = MinimalRNNCellPrunable(5)
175172
layer = layers.RNN([cell1, cell2])
176173
with self.assertRaises(AttributeError):
177174
layer.get_prunable_weights()
@@ -187,12 +184,9 @@ def testMakePrunableWorksOnKerasRNNLayerWithPrunableCell(self):
187184

188185
def testMakePrunableRaisesErrorOnRNNLayersUnsupportedCell(self):
189186
with self.assertRaises(ValueError):
190-
PruneRegistry.make_prunable(layers.RNN(
191-
[layers.LSTMCell(10), PruneRegistryTest.MinimalRNNCell(5)]))
192-
193-
def testRescalingLayer(self):
194-
self.assertTrue(PruneRegistry.supports(
195-
tf.keras.layers.experimental.preprocessing.Rescaling))
187+
PruneRegistry.make_prunable(
188+
layers.RNN([layers.LSTMCell(10),
189+
MinimalRNNCell(5)]))
196190

197191

198192
if __name__ == '__main__':

0 commit comments

Comments
 (0)