
Commit d33ef45

nutsiepully authored and tensorflower-gardener committed
SeparableConv1D matches with Squeeze layer
SepConv1D generates squeeze layers in between, so the matches require additional patterns to account for the squeeze layer.

PiperOrigin-RevId: 324901553
1 parent 96e0025 commit d33ef45
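For context, here is a minimal sketch (a hypothetical model, not part of this commit) of the kind of Keras model these transforms are written for: a SeparableConv1D followed by BatchNormalization and ReLU. Quantizing it through the public tfmot.quantization.keras.quantize_model API runs the default 8-bit transforms, including the new squeeze-aware Conv/BatchNorm patterns; the layer sizes and shapes below are illustrative assumptions.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Hypothetical model: SeparableConv1D -> BatchNormalization -> ReLU.
# The SeparableConv1D transform rewrites the 1D conv into 2D layers plus a
# squeeze Lambda, which is why the Conv/BatchNorm patterns need a squeeze-aware
# ("Reshape") variant.
model = tf.keras.Sequential([
    tf.keras.layers.SeparableConv1D(8, 3, input_shape=(20, 4)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
])

# Applying quantization runs the default 8-bit transforms over the model graph.
quantized_model = tfmot.quantization.keras.quantize_model(model)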

File tree

2 files changed: +70 −17 lines changed


tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 66 additions & 17 deletions
@@ -177,18 +177,22 @@ def pattern(self):
         inputs=[LayerPattern(
             'Conv2D|DepthwiseConv2D', config={'activation': 'linear'})])
 
-  def replacement(self, match_layer):
-    bn_layer_node, conv_layer_node = match_layer, match_layer.input_layers[0]
-
+  def _replace(self, bn_layer_node, conv_layer_node):
     if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
-      return match_layer
+      return bn_layer_node
 
     conv_layer_node.layer['config']['activation'] = \
       keras.activations.serialize(quantize_aware_activation.NoOpActivation())
     bn_layer_node.metadata['quantize_config'] = \
       default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()
 
-    return match_layer
+    return bn_layer_node
+
+  def replacement(self, match_layer):
+    bn_layer_node = match_layer
+    conv_layer_node = match_layer.input_layers[0]
+
+    return self._replace(bn_layer_node, conv_layer_node)
 
   def custom_objects(self):
     return {
@@ -197,6 +201,26 @@ def custom_objects(self):
     }
 
 
+class Conv2DReshapeBatchNormQuantize(Conv2DBatchNormQuantize):
+  """Ensure FQ does not get placed between Conv, Reshape and BatchNorm."""
+
+  def pattern(self):
+    return LayerPattern(
+        'BatchNormalization',
+        inputs=[LayerPattern(
+            'Lambda', config={'name': 'sepconv1d_squeeze.*'},
+            inputs=[LayerPattern(
+                'Conv2D|DepthwiseConv2D',
+                config={'activation': 'linear'})])])
+
+  def replacement(self, match_layer):
+    bn_layer_node = match_layer
+    reshape_layer_node = bn_layer_node.input_layers[0]
+    conv_layer_node = reshape_layer_node.input_layers[0]
+
+    return self._replace(bn_layer_node, conv_layer_node)
+
+
 class Conv2DBatchNormReLUQuantize(Conv2DBatchNormQuantize):
   """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
 
@@ -206,27 +230,24 @@ def pattern(self):
         'ReLU',
         inputs=[super(Conv2DBatchNormReLUQuantize, self).pattern()])
 
-  def replacement(self, match_layer):
-    relu_layer_node = match_layer
-    bn_layer_node = relu_layer_node.input_layers[0]
-    conv_layer_node = bn_layer_node.input_layers[0]
-
+  def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
     if _has_custom_quantize_config(
         relu_layer_node, bn_layer_node, conv_layer_node):
-      return match_layer
+      return relu_layer_node
 
     conv_layer_node.layer['config']['activation'] = \
       keras.activations.serialize(quantize_aware_activation.NoOpActivation())
     bn_layer_node.metadata['quantize_config'] = \
       default_8bit_quantize_configs.NoOpQuantizeConfig()
 
-    return match_layer
+    return relu_layer_node
 
-  def custom_objects(self):
-    return {
-        'NoOpQuantizeConfig': default_8bit_quantize_configs.NoOpQuantizeConfig,
-        'NoOpActivation': quantize_aware_activation.NoOpActivation
-    }
+  def replacement(self, match_layer):
+    relu_layer_node = match_layer
+    bn_layer_node = relu_layer_node.input_layers[0]
+    conv_layer_node = bn_layer_node.input_layers[0]
+
+    return self._replace(relu_layer_node, bn_layer_node, conv_layer_node)
 
 
 class Conv2DBatchNormActivationQuantize(Conv2DBatchNormReLUQuantize):
@@ -239,6 +260,34 @@ def pattern(self):
         inputs=[Conv2DBatchNormQuantize.pattern(self)])
 
 
+class Conv2DReshapeBatchNormReLUQuantize(Conv2DBatchNormReLUQuantize):
+  """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
+
+  def pattern(self):
+    return LayerPattern(
+        'ReLU',
+        inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
+
+  def replacement(self, match_layer):
+    relu_layer_node = match_layer
+    bn_layer_node = relu_layer_node.input_layers[0]
+    squeeze_layer_node = bn_layer_node.input_layers[0]
+    conv_layer_node = squeeze_layer_node.input_layers[0]
+
+    return self._replace(relu_layer_node, bn_layer_node, conv_layer_node)
+
+
+class Conv2DReshapeBatchNormActivationQuantize(
+    Conv2DReshapeBatchNormReLUQuantize):
+  """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
+
+  def pattern(self):
+    return LayerPattern(
+        'Activation',
+        config={'activation': 'relu'},
+        inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
+
+
 class SeparableConv1DQuantize(transforms.Transform):
   """Add QAT support for Keras SeparableConv1D layer.
 
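To make the new pattern concrete, here is a hedged illustration (layer names, kernel sizes, and shapes are assumptions, not taken from the repo) of the expanded layer sequence Conv2DReshapeBatchNormQuantize matches: a Conv2D/DepthwiseConv2D with linear activation, a squeeze Lambda whose name matches 'sepconv1d_squeeze.*', then BatchNormalization.

import tensorflow as tf

# Assumed expanded form of a SeparableConv1D block after the 1D -> 2D rewrite;
# only the layer ordering and the 'sepconv1d_squeeze*' Lambda name mirror the
# pattern added in this commit, the shapes are illustrative.
inputs = tf.keras.Input(shape=(1, 20, 4))
x = tf.keras.layers.Conv2D(8, (1, 3), activation='linear')(inputs)
x = tf.keras.layers.Lambda(
    lambda t: tf.squeeze(t, axis=1), name='sepconv1d_squeeze_1')(x)
outputs = tf.keras.layers.BatchNormalization()(x)
model = tf.keras.Model(inputs, outputs)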
tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 4 additions & 0 deletions
@@ -337,6 +337,10 @@ def testSeparableConvQuantize_(self, kwargs):
     self.assertAllClose(sepconv_model.predict(x), transformed_model.predict(x),
                         atol=1e-5, rtol=1e-5)
 
+  # TODO(pulkitb): Add individual tests for the following transforms.
+  # Conv2DReshapeBatchNormQuantize, Conv2DReshapeBatchNormReLUQuantize
+  # Conv2DReshapeBatchNormActivationQuantize
+
   @parameterized.parameters(
       ('relu', default_8bit_transforms.AddReLUQuantize),
       ('act_relu', default_8bit_transforms.AddActivationQuantize),
