Skip to content

Commit 6e7e42d

Browse files
dayeongl authored and tensorflower-gardener committed
Add DenseBatchNormReLU transform for non-folded case
PiperOrigin-RevId: 409027616
1 parent 60a2228 commit 6e7e42d

File tree

6 files changed

+233
-41
lines changed

6 files changed

+233
-41
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ py_strict_test(
122122
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
123123
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
124124
"//tensorflow_model_optimization/python/core/quantization/keras/layers:conv_batchnorm_test_utils",
125+
"//tensorflow_model_optimization/python/core/quantization/keras/layers:dense_batchnorm_test_utils",
125126
],
126127
)
127128

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def apply(self, model, layer_quantize_map):
6565
default_8bit_transforms.ConcatTransform4Inputs(),
6666
default_8bit_transforms.ConcatTransform3Inputs(),
6767
default_8bit_transforms.ConcatTransform(),
68+
default_8bit_transforms.DenseBatchNormQuantize(),
69+
default_8bit_transforms.DenseBatchNormReLUQuantize(),
70+
default_8bit_transforms.DenseBatchNormActivationQuantize(),
6871
default_8bit_transforms.LayerReLUQuantize(),
6972
default_8bit_transforms.LayerReluActivationQuantize(),
7073
]

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,75 @@ def pattern(self):
221221
inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
222222

223223

224+
class DenseBatchNormQuantize(transforms.Transform):
  """Ensure FQ does not get placed between Dense and BatchNorm.

  Matches a (Sync)BatchNormalization layer fed by a linear-activation Dense
  layer, and rewires quantization so that no fake-quant op is inserted
  between the two layers: the Dense activation becomes a NoOpActivation and
  only the BatchNorm output is quantized.
  """

  def pattern(self):
    # Match BatchNorm(Dense(x)) where the Dense has no fused activation.
    dense = LayerPattern('Dense', config={'activation': 'linear'})
    return LayerPattern(
        'BatchNormalization|SyncBatchNormalization', inputs=[dense])

  def _replace(self, bn_layer_node, dense_layer_node):
    # Respect any user-provided quantize config: leave the match untouched.
    if _has_custom_quantize_config(bn_layer_node, dense_layer_node):
      return bn_layer_node

    # Suppress quantization of the Dense output by swapping in a no-op
    # activation, and quantize only after the BatchNorm.
    no_op = quantize_aware_activation.NoOpActivation()
    dense_layer_node.layer['config']['activation'] = (
        keras.activations.serialize(no_op))
    bn_layer_node.metadata['quantize_config'] = (
        default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

    return bn_layer_node

  def replacement(self, match_layer):
    # The match root is the BatchNorm; its single input is the Dense.
    return self._replace(match_layer, match_layer.input_layers[0])

  def custom_objects(self):
    # Objects needed to deserialize the transformed model.
    return {
        'NoOpQuantizeConfig': default_8bit_quantize_configs.NoOpQuantizeConfig,
        'NoOpActivation': quantize_aware_activation.NoOpActivation
    }
254+
255+
256+
class DenseBatchNormReLUQuantize(DenseBatchNormQuantize):
  """Ensure FQ does not get placed between Dense, BatchNorm and ReLU.

  Extends the Dense + BatchNorm pattern with a trailing ReLU layer: the
  Dense activation becomes a NoOpActivation, the BatchNorm output is left
  unquantized (NoOpQuantizeConfig), and quantization happens only after
  the ReLU.
  """

  def pattern(self):
    # Wrap the parent's BatchNorm(Dense) pattern in a ReLU layer.
    # NOTE: zero-argument super() — this codebase is PY3-only (see BUILD
    # srcs_version = "PY3"), so the legacy two-argument form is unneeded.
    return LayerPattern('ReLU', inputs=[super().pattern()])

  def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
    # Respect any user-provided quantize config: leave the match untouched.
    if _has_custom_quantize_config(relu_layer_node, bn_layer_node,
                                   dense_layer_node):
      return relu_layer_node

    dense_layer_node.layer['config']['activation'] = (
        keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
    # Unlike the parent transform, the BatchNorm output is NOT quantized
    # here — quantization is deferred until after the ReLU.
    bn_layer_node.metadata['quantize_config'] = (
        default_8bit_quantize_configs.NoOpQuantizeConfig())

    return relu_layer_node

  def replacement(self, match_layer):
    relu_layer_node = match_layer
    bn_layer_node = relu_layer_node.input_layers[0]
    dense_layer_node = bn_layer_node.input_layers[0]

    return self._replace(relu_layer_node, bn_layer_node, dense_layer_node)
281+
282+
283+
class DenseBatchNormActivationQuantize(DenseBatchNormReLUQuantize):
  """Ensure FQ does not get placed between Dense, BatchNorm and ReLU.

  Same as DenseBatchNormReLUQuantize, but matches an Activation('relu')
  layer instead of a dedicated ReLU layer.
  """

  def pattern(self):
    # Deliberately call the grandparent's pattern: super() here is
    # DenseBatchNormReLUQuantize, whose pattern would wrap BatchNorm(Dense)
    # in an extra ReLU layer. We want Activation(BatchNorm(Dense)) only.
    inner = DenseBatchNormQuantize.pattern(self)
    return LayerPattern(
        'Activation', config={'activation': 'relu'}, inputs=[inner])
291+
292+
224293
class SeparableConv1DQuantize(transforms.Transform):
225294
"""Add QAT support for Keras SeparableConv1D layer.
226295

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 93 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@
2929
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_transforms
3030
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
3131
from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
32+
from tensorflow_model_optimization.python.core.quantization.keras.layers import dense_batchnorm_test_utils
3233

3334
ModelTransformer = model_transformer.ModelTransformer
3435

3536
Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
3637
DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
38+
DenseModel = dense_batchnorm_test_utils.DenseModel
3739

3840
keras = tf.keras
3941

@@ -73,22 +75,27 @@ def _get_model(
7375
post_bn_activation=activation,
7476
squeeze_type=squeeze_type,
7577
normalization_type=normalization_type)
78+
elif layer_type == 'Dense':
79+
return DenseModel.get_nonfolded_batchnorm_model(
80+
post_bn_activation=activation, normalization_type=normalization_type)
7681

7782
def _get_input_shape(self, layer_type):
7883
if layer_type == 'Conv2D':
7984
return Conv2DModel.get_batched_input_shape()
8085
elif layer_type == 'DepthwiseConv2D':
8186
return DepthwiseConv2DModel.get_batched_input_shape()
87+
elif layer_type == 'Dense':
88+
return DenseModel.get_batched_input_shape()
8289

83-
def _test_conv_squeeze_bn_activation_transform(
90+
def _test_conv_squeeze_or_dense_bn_activation_transform(
8491
self,
8592
layer_type,
86-
squeeze_type,
8793
normalization_type,
8894
activation_type,
8995
transform_class,
90-
conv_activation_class,
91-
normalization_quantize_config_class):
96+
conv_or_dense_activation_class,
97+
normalization_quantize_config_class,
98+
squeeze_type=None):
9299
model = self._get_model(layer_type,
93100
squeeze_type,
94101
normalization_type,
@@ -100,14 +107,14 @@ def _test_conv_squeeze_bn_activation_transform(
100107
[transform_class()],
101108
).transform()
102109

103-
conv_layer = transformed_model.layers[1]
110+
conv_or_dense_layer = transformed_model.layers[1]
104111
if squeeze_type == 'sepconv1d_squeeze':
105112
bn_layer = transformed_model.layers[3]
106113
else:
107114
bn_layer = transformed_model.layers[2]
108115

109-
self.assertIsInstance(
110-
conv_layer.activation, conv_activation_class)
116+
self.assertIsInstance(conv_or_dense_layer.activation,
117+
conv_or_dense_activation_class)
111118
self.assertIsInstance(
112119
updated_metadata.get(bn_layer.name).get('quantize_config'),
113120
normalization_quantize_config_class)
@@ -123,15 +130,15 @@ def _test_conv_squeeze_bn_activation_transform(
123130
('DepthwiseConv2D', 'SyncBatchNormalization'),
124131
)
125132
def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
126-
self._test_conv_squeeze_bn_activation_transform(
133+
self._test_conv_squeeze_or_dense_bn_activation_transform(
127134
layer_type=layer_type,
128135
squeeze_type=None,
129136
normalization_type=normalization_type,
130137
activation_type=None,
131138
transform_class=default_8bit_transforms.Conv2DBatchNormQuantize,
132-
conv_activation_class=quantize_aware_activation.NoOpActivation,
133-
normalization_quantize_config_class=
134-
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
139+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
140+
normalization_quantize_config_class=default_8bit_quantize_configs
141+
.Default8BitOutputQuantizeConfig)
135142

136143
@parameterized.parameters(
137144
('Conv2D', 'BatchNormalization'),
@@ -140,16 +147,15 @@ def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
140147
('DepthwiseConv2D', 'SyncBatchNormalization'),
141148
)
142149
def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
143-
self._test_conv_squeeze_bn_activation_transform(
150+
self._test_conv_squeeze_or_dense_bn_activation_transform(
144151
layer_type=layer_type,
145152
squeeze_type=None,
146153
normalization_type=normalization_type,
147154
activation_type='relu',
148-
transform_class=
149-
default_8bit_transforms.Conv2DBatchNormReLUQuantize,
150-
conv_activation_class=quantize_aware_activation.NoOpActivation,
151-
normalization_quantize_config_class=
152-
default_8bit_quantize_configs.NoOpQuantizeConfig)
155+
transform_class=default_8bit_transforms.Conv2DBatchNormReLUQuantize,
156+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
157+
normalization_quantize_config_class=default_8bit_quantize_configs
158+
.NoOpQuantizeConfig)
153159

154160
@parameterized.parameters(
155161
('Conv2D', 'BatchNormalization'),
@@ -159,16 +165,16 @@ def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
159165
)
160166
def testConv2DBatchNormActivationQuantize(
161167
self, layer_type, normalization_type):
162-
self._test_conv_squeeze_bn_activation_transform(
168+
self._test_conv_squeeze_or_dense_bn_activation_transform(
163169
layer_type=layer_type,
164170
squeeze_type=None,
165171
normalization_type=normalization_type,
166172
activation_type='act_relu',
167-
transform_class=
168-
default_8bit_transforms.Conv2DBatchNormActivationQuantize,
169-
conv_activation_class=quantize_aware_activation.NoOpActivation,
170-
normalization_quantize_config_class=
171-
default_8bit_quantize_configs.NoOpQuantizeConfig)
173+
transform_class=default_8bit_transforms
174+
.Conv2DBatchNormActivationQuantize,
175+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
176+
normalization_quantize_config_class=default_8bit_quantize_configs
177+
.NoOpQuantizeConfig)
172178

173179
@parameterized.parameters(
174180
('Conv2D', 'BatchNormalization'),
@@ -178,16 +184,15 @@ def testConv2DBatchNormActivationQuantize(
178184
)
179185
def testConv2DReshapeBatchNormQuantize(
180186
self, layer_type, normalization_type):
181-
self._test_conv_squeeze_bn_activation_transform(
187+
self._test_conv_squeeze_or_dense_bn_activation_transform(
182188
layer_type=layer_type,
183189
squeeze_type='sepconv1d_squeeze',
184190
normalization_type=normalization_type,
185191
activation_type=False,
186-
transform_class=
187-
default_8bit_transforms.Conv2DReshapeBatchNormQuantize,
188-
conv_activation_class=quantize_aware_activation.NoOpActivation,
189-
normalization_quantize_config_class=
190-
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
192+
transform_class=default_8bit_transforms.Conv2DReshapeBatchNormQuantize,
193+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
194+
normalization_quantize_config_class=default_8bit_quantize_configs
195+
.Default8BitOutputQuantizeConfig)
191196

192197
@parameterized.parameters(
193198
('Conv2D', 'BatchNormalization'),
@@ -197,16 +202,16 @@ def testConv2DReshapeBatchNormQuantize(
197202
)
198203
def testConv2DReshapeBatchNormReLUQuantize(
199204
self, layer_type, normalization_type):
200-
self._test_conv_squeeze_bn_activation_transform(
205+
self._test_conv_squeeze_or_dense_bn_activation_transform(
201206
layer_type=layer_type,
202207
squeeze_type='sepconv1d_squeeze',
203208
normalization_type=normalization_type,
204209
activation_type='relu',
205-
transform_class=
206-
default_8bit_transforms.Conv2DReshapeBatchNormReLUQuantize,
207-
conv_activation_class=quantize_aware_activation.NoOpActivation,
208-
normalization_quantize_config_class=
209-
default_8bit_quantize_configs.NoOpQuantizeConfig)
210+
transform_class=default_8bit_transforms
211+
.Conv2DReshapeBatchNormReLUQuantize,
212+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
213+
normalization_quantize_config_class=default_8bit_quantize_configs
214+
.NoOpQuantizeConfig)
210215

211216
@parameterized.parameters(
212217
('Conv2D', 'BatchNormalization'),
@@ -216,16 +221,63 @@ def testConv2DReshapeBatchNormReLUQuantize(
216221
)
217222
def testConv2DReshapeBatchNormActivationQuantize(
218223
self, layer_type, normalization_type):
219-
self._test_conv_squeeze_bn_activation_transform(
224+
self._test_conv_squeeze_or_dense_bn_activation_transform(
220225
layer_type=layer_type,
221226
squeeze_type='sepconv1d_squeeze',
222227
normalization_type=normalization_type,
223228
activation_type='act_relu',
224-
transform_class=
225-
default_8bit_transforms.Conv2DReshapeBatchNormActivationQuantize,
226-
conv_activation_class=quantize_aware_activation.NoOpActivation,
227-
normalization_quantize_config_class=
228-
default_8bit_quantize_configs.NoOpQuantizeConfig)
229+
transform_class=default_8bit_transforms
230+
.Conv2DReshapeBatchNormActivationQuantize,
231+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
232+
normalization_quantize_config_class=default_8bit_quantize_configs
233+
.NoOpQuantizeConfig)
234+
235+
@parameterized.parameters(
236+
('Dense', 'BatchNormalization'),
237+
('Dense', 'SyncBatchNormalization'),
238+
)
239+
def testDenseBatchNormQuantize(self, layer_type, normalization_type):
240+
self._test_conv_squeeze_or_dense_bn_activation_transform(
241+
layer_type=layer_type,
242+
squeeze_type=None,
243+
normalization_type=normalization_type,
244+
activation_type=None,
245+
transform_class=default_8bit_transforms.DenseBatchNormQuantize,
246+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
247+
normalization_quantize_config_class=default_8bit_quantize_configs
248+
.Default8BitOutputQuantizeConfig)
249+
250+
@parameterized.parameters(
251+
('Dense', 'BatchNormalization'),
252+
('Dense', 'SyncBatchNormalization'),
253+
)
254+
def testDenseBatchNormReLUQuantize(self, layer_type, normalization_type):
255+
self._test_conv_squeeze_or_dense_bn_activation_transform(
256+
layer_type=layer_type,
257+
squeeze_type=None,
258+
normalization_type=normalization_type,
259+
activation_type='relu',
260+
transform_class=default_8bit_transforms.DenseBatchNormReLUQuantize,
261+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
262+
normalization_quantize_config_class=default_8bit_quantize_configs
263+
.NoOpQuantizeConfig)
264+
265+
@parameterized.parameters(
266+
('Dense', 'BatchNormalization'),
267+
('Dense', 'SyncBatchNormalization'),
268+
)
269+
def testDenseBatchNormActivationQuantize(self, layer_type,
270+
normalization_type):
271+
self._test_conv_squeeze_or_dense_bn_activation_transform(
272+
layer_type=layer_type,
273+
squeeze_type=None,
274+
normalization_type=normalization_type,
275+
activation_type='act_relu',
276+
transform_class=default_8bit_transforms
277+
.DenseBatchNormActivationQuantize,
278+
conv_or_dense_activation_class=quantize_aware_activation.NoOpActivation,
279+
normalization_quantize_config_class=default_8bit_quantize_configs
280+
.NoOpQuantizeConfig)
229281

230282
@parameterized.named_parameters(
231283
('padding_valid', {'padding': 'valid'}),

tensorflow_model_optimization/python/core/quantization/keras/layers/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,12 @@ py_strict_library(
2323
# tensorflow dep1,
2424
],
2525
)
26+
27+
py_strict_library(
28+
name = "dense_batchnorm_test_utils",
29+
srcs = ["dense_batchnorm_test_utils.py"],
30+
srcs_version = "PY3",
31+
deps = [
32+
# tensorflow dep1,
33+
],
34+
)

0 commit comments

Comments
 (0)