Skip to content

Commit 4924fd3

Browse files
daverim and tensorflower-gardener
authored and committed
Add DenseBatchNorm.* transforms to default_n_bit scheme
PiperOrigin-RevId: 428949606
1 parent 1455e6e commit 4924fd3

File tree

4 files changed

+168
-15
lines changed

4 files changed

+168
-15
lines changed

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ py_strict_test(
134134
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
135135
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
136136
"//tensorflow_model_optimization/python/core/quantization/keras/layers:conv_batchnorm_test_utils",
137+
"//tensorflow_model_optimization/python/core/quantization/keras/layers:dense_batchnorm_test_utils",
137138
],
138139
)
139140

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_layout_transform.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def apply(self, model, layer_quantize_map):
104104
default_n_bit_transforms.LayerReluActivationQuantize(
105105
num_bits_weight=self._num_bits_weight,
106106
num_bits_activation=self._num_bits_activation),
107+
default_n_bit_transforms.DenseBatchNormQuantize(
108+
num_bits_weight=self._num_bits_weight,
109+
num_bits_activation=self._num_bits_activation),
110+
default_n_bit_transforms.DenseBatchNormReLUQuantize(
111+
num_bits_weight=self._num_bits_weight,
112+
num_bits_activation=self._num_bits_activation),
113+
default_n_bit_transforms.DenseBatchNormActivationQuantize(
114+
num_bits_weight=self._num_bits_weight,
115+
num_bits_activation=self._num_bits_activation),
107116
]
108117
return model_transformer.ModelTransformer(
109118
model, transforms,

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,95 @@ def pattern(self):
263263
inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
264264

265265

266+
class DenseBatchNormQuantize(transforms.Transform):
267+
"""Transform to be applied to "Dense" + "BatchNorm" Graph.
268+
269+
This transform disables Quantization between Dense and BatchNorm
270+
to ensure FQ does not get placed between them.
271+
"""
272+
273+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
274+
self._num_bits_weight = num_bits_weight
275+
self._num_bits_activation = num_bits_activation
276+
277+
def pattern(self):
278+
return LayerPattern(
279+
'BatchNormalization|SyncBatchNormalization',
280+
inputs=[LayerPattern('Dense', config={'activation': 'linear'})])
281+
282+
def _replace(self, bn_layer_node, dense_layer_node):
283+
if _has_custom_quantize_config(bn_layer_node, dense_layer_node):
284+
return bn_layer_node
285+
286+
dense_layer_node.layer['config']['activation'] = (
287+
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
288+
bn_layer_node.metadata['quantize_config'] = (
289+
configs.DefaultNBitOutputQuantizeConfig(
290+
num_bits_weight=self._num_bits_weight,
291+
num_bits_activation=self._num_bits_activation))
292+
return bn_layer_node
293+
294+
def replacement(self, match_layer):
295+
bn_layer_node = match_layer
296+
dense_layer_node = match_layer.input_layers[0]
297+
298+
return self._replace(bn_layer_node, dense_layer_node)
299+
300+
def custom_objects(self):
301+
return {
302+
'DefaultNBitOutputQuantizeConfig':
303+
configs.DefaultNBitOutputQuantizeConfig,
304+
'NoOpQuantizeConfig':
305+
configs.NoOpQuantizeConfig,
306+
'NoOpActivation': quantize_aware_activation.NoOpActivation
307+
}
308+
309+
310+
class DenseBatchNormReLUQuantize(DenseBatchNormQuantize):
311+
"""Transform to be applied to "Dense" + "BatchNorm" + "ReLU" Graph.
312+
313+
This transform disables Quantization between Dense, BatchNorm and ReLU
314+
to ensure FQ does not get placed between them.
315+
"""
316+
317+
def pattern(self):
318+
return LayerPattern(
319+
'ReLU', inputs=[super(DenseBatchNormReLUQuantize, self).pattern()])
320+
321+
def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
322+
if _has_custom_quantize_config(relu_layer_node, bn_layer_node,
323+
dense_layer_node):
324+
return relu_layer_node
325+
326+
dense_layer_node.layer['config']['activation'] = (
327+
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
328+
bn_layer_node.metadata['quantize_config'] = (
329+
configs.NoOpQuantizeConfig())
330+
331+
return relu_layer_node
332+
333+
def replacement(self, match_layer):
334+
relu_layer_node = match_layer
335+
bn_layer_node = relu_layer_node.input_layers[0]
336+
dense_layer_node = bn_layer_node.input_layers[0]
337+
338+
return self._replace(relu_layer_node, bn_layer_node, dense_layer_node)
339+
340+
341+
class DenseBatchNormActivationQuantize(DenseBatchNormReLUQuantize):
342+
"""Transform to be applied to "Dense" + "BatchNorm" + "Activation"('relu') Graph.
343+
344+
This transform disables Quantization between Dense, BatchNorm and Activation
345+
to ensure FQ does not get placed between them.
346+
"""
347+
348+
def pattern(self):
349+
return LayerPattern(
350+
'Activation',
351+
config={'activation': 'relu'},
352+
inputs=[DenseBatchNormQuantize.pattern(self)])
353+
354+
266355
class SeparableConv1DQuantize(transforms.Transform):
267356
"""Add QAT support for Keras SeparableConv1D layer.
268357

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@
2929
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_transforms
3030
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
3131
from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
32+
from tensorflow_model_optimization.python.core.quantization.keras.layers import dense_batchnorm_test_utils
3233

3334
ModelTransformer = model_transformer.ModelTransformer
3435

3536
Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
3637
DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
38+
DenseModel = dense_batchnorm_test_utils.DenseModel
3739

3840
keras = tf.keras
3941

@@ -73,21 +75,26 @@ def _get_model(
7375
post_bn_activation=activation,
7476
squeeze_type=squeeze_type,
7577
normalization_type=normalization_type)
78+
elif layer_type == 'Dense':
79+
return DenseModel.get_nonfolded_batchnorm_model(
80+
post_bn_activation=activation, normalization_type=normalization_type)
7681

7782
def _get_input_shape(self, layer_type):
7883
if layer_type == 'Conv2D':
7984
return Conv2DModel.get_batched_input_shape()
8085
elif layer_type == 'DepthwiseConv2D':
8186
return DepthwiseConv2DModel.get_batched_input_shape()
87+
elif layer_type == 'Dense':
88+
return DenseModel.get_batched_input_shape()
8289

83-
def _test_conv_squeeze_bn_activation_transform(
90+
def _test_conv_squeeze_or_dense_bn_activation_transform(
8491
self,
8592
layer_type,
8693
squeeze_type,
8794
normalization_type,
8895
activation_type,
8996
transform_class,
90-
conv_activation_class,
97+
conv_or_dense_activation_class,
9198
normalization_quantize_config_class):
9299
model = self._get_model(layer_type,
93100
squeeze_type,
@@ -107,7 +114,7 @@ def _test_conv_squeeze_bn_activation_transform(
107114
bn_layer = transformed_model.layers[2]
108115

109116
self.assertIsInstance(
110-
conv_layer.activation, conv_activation_class)
117+
conv_layer.activation, conv_or_dense_activation_class)
111118
self.assertIsInstance(
112119
updated_metadata.get(bn_layer.name).get('quantize_config'),
113120
normalization_quantize_config_class)
@@ -123,13 +130,13 @@ def _test_conv_squeeze_bn_activation_transform(
123130
('DepthwiseConv2D', 'SyncBatchNormalization'),
124131
)
125132
def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
126-
self._test_conv_squeeze_bn_activation_transform(
133+
self._test_conv_squeeze_or_dense_bn_activation_transform(
127134
layer_type=layer_type,
128135
squeeze_type=None,
129136
normalization_type=normalization_type,
130137
activation_type=None,
131138
transform_class=default_n_bit_transforms.Conv2DBatchNormQuantize,
132-
conv_activation_class=quantize_aware_activation.NoOpActivation,
139+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
133140
normalization_quantize_config_class=
134141
n_bit_configs.DefaultNBitOutputQuantizeConfig)
135142

@@ -140,14 +147,14 @@ def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
140147
('DepthwiseConv2D', 'SyncBatchNormalization'),
141148
)
142149
def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
143-
self._test_conv_squeeze_bn_activation_transform(
150+
self._test_conv_squeeze_or_dense_bn_activation_transform(
144151
layer_type=layer_type,
145152
squeeze_type=None,
146153
normalization_type=normalization_type,
147154
activation_type='relu',
148155
transform_class=
149156
default_n_bit_transforms.Conv2DBatchNormReLUQuantize,
150-
conv_activation_class=quantize_aware_activation.NoOpActivation,
157+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
151158
normalization_quantize_config_class=
152159
n_bit_configs.NoOpQuantizeConfig)
153160

@@ -159,14 +166,14 @@ def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
159166
)
160167
def testConv2DBatchNormActivationQuantize(
161168
self, layer_type, normalization_type):
162-
self._test_conv_squeeze_bn_activation_transform(
169+
self._test_conv_squeeze_or_dense_bn_activation_transform(
163170
layer_type=layer_type,
164171
squeeze_type=None,
165172
normalization_type=normalization_type,
166173
activation_type='act_relu',
167174
transform_class=
168175
default_n_bit_transforms.Conv2DBatchNormActivationQuantize,
169-
conv_activation_class=quantize_aware_activation.NoOpActivation,
176+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
170177
normalization_quantize_config_class=
171178
n_bit_configs.NoOpQuantizeConfig)
172179

@@ -178,14 +185,14 @@ def testConv2DBatchNormActivationQuantize(
178185
)
179186
def testConv2DReshapeBatchNormQuantize(
180187
self, layer_type, normalization_type):
181-
self._test_conv_squeeze_bn_activation_transform(
188+
self._test_conv_squeeze_or_dense_bn_activation_transform(
182189
layer_type=layer_type,
183190
squeeze_type='sepconv1d_squeeze',
184191
normalization_type=normalization_type,
185192
activation_type=False,
186193
transform_class=
187194
default_n_bit_transforms.Conv2DReshapeBatchNormQuantize,
188-
conv_activation_class=quantize_aware_activation.NoOpActivation,
195+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
189196
normalization_quantize_config_class=
190197
n_bit_configs.DefaultNBitOutputQuantizeConfig)
191198

@@ -197,14 +204,14 @@ def testConv2DReshapeBatchNormQuantize(
197204
)
198205
def testConv2DReshapeBatchNormReLUQuantize(
199206
self, layer_type, normalization_type):
200-
self._test_conv_squeeze_bn_activation_transform(
207+
self._test_conv_squeeze_or_dense_bn_activation_transform(
201208
layer_type=layer_type,
202209
squeeze_type='sepconv1d_squeeze',
203210
normalization_type=normalization_type,
204211
activation_type='relu',
205212
transform_class=
206213
default_n_bit_transforms.Conv2DReshapeBatchNormReLUQuantize,
207-
conv_activation_class=quantize_aware_activation.NoOpActivation,
214+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
208215
normalization_quantize_config_class=
209216
n_bit_configs.NoOpQuantizeConfig)
210217

@@ -216,17 +223,64 @@ def testConv2DReshapeBatchNormReLUQuantize(
216223
)
217224
def testConv2DReshapeBatchNormActivationQuantize(
218225
self, layer_type, normalization_type):
219-
self._test_conv_squeeze_bn_activation_transform(
226+
self._test_conv_squeeze_or_dense_bn_activation_transform(
220227
layer_type=layer_type,
221228
squeeze_type='sepconv1d_squeeze',
222229
normalization_type=normalization_type,
223230
activation_type='act_relu',
224231
transform_class=
225232
default_n_bit_transforms.Conv2DReshapeBatchNormActivationQuantize,
226-
conv_activation_class=quantize_aware_activation.NoOpActivation,
233+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
227234
normalization_quantize_config_class=
228235
n_bit_configs.NoOpQuantizeConfig)
229236

237+
@parameterized.parameters(
238+
('Dense', 'BatchNormalization'),
239+
('Dense', 'SyncBatchNormalization'),
240+
)
241+
def testDenseBatchNormQuantize(self, layer_type, normalization_type):
242+
self._test_conv_squeeze_or_dense_bn_activation_transform(
243+
layer_type=layer_type,
244+
squeeze_type=None,
245+
normalization_type=normalization_type,
246+
activation_type=None,
247+
transform_class=default_n_bit_transforms.DenseBatchNormQuantize,
248+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
249+
normalization_quantize_config_class=n_bit_configs
250+
.DefaultNBitOutputQuantizeConfig)
251+
252+
@parameterized.parameters(
253+
('Dense', 'BatchNormalization'),
254+
('Dense', 'SyncBatchNormalization'),
255+
)
256+
def testDenseBatchNormReLUQuantize(self, layer_type, normalization_type):
257+
self._test_conv_squeeze_or_dense_bn_activation_transform(
258+
layer_type=layer_type,
259+
squeeze_type=None,
260+
normalization_type=normalization_type,
261+
activation_type='relu',
262+
transform_class=default_n_bit_transforms.DenseBatchNormReLUQuantize,
263+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
264+
normalization_quantize_config_class=n_bit_configs
265+
.NoOpQuantizeConfig)
266+
267+
@parameterized.parameters(
268+
('Dense', 'BatchNormalization'),
269+
('Dense', 'SyncBatchNormalization'),
270+
)
271+
def testDenseBatchNormActivationQuantize(self, layer_type,
272+
normalization_type):
273+
self._test_conv_squeeze_or_dense_bn_activation_transform(
274+
layer_type=layer_type,
275+
squeeze_type=None,
276+
normalization_type=normalization_type,
277+
activation_type='act_relu',
278+
transform_class=default_n_bit_transforms
279+
.DenseBatchNormActivationQuantize,
280+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
281+
normalization_quantize_config_class=n_bit_configs
282+
.NoOpQuantizeConfig)
283+
230284
@parameterized.named_parameters(
231285
('padding_valid', {'padding': 'valid'}),
232286
('padding_same', {'padding': 'same'}),

0 commit comments

Comments
 (0)