@@ -177,18 +177,22 @@ def pattern(self):
177177 inputs = [LayerPattern (
178178 'Conv2D|DepthwiseConv2D' , config = {'activation' : 'linear' })])
179179
180- def replacement (self , match_layer ):
181- bn_layer_node , conv_layer_node = match_layer , match_layer .input_layers [0 ]
182-
180+ def _replace (self , bn_layer_node , conv_layer_node ):
183181 if _has_custom_quantize_config (bn_layer_node , conv_layer_node ):
184- return match_layer
182+ return bn_layer_node
185183
186184 conv_layer_node .layer ['config' ]['activation' ] = \
187185 keras .activations .serialize (quantize_aware_activation .NoOpActivation ())
188186 bn_layer_node .metadata ['quantize_config' ] = \
189187 default_8bit_quantize_configs .Default8BitOutputQuantizeConfig ()
190188
191- return match_layer
189+ return bn_layer_node
190+
191+ def replacement (self , match_layer ):
192+ bn_layer_node = match_layer
193+ conv_layer_node = match_layer .input_layers [0 ]
194+
195+ return self ._replace (bn_layer_node , conv_layer_node )
192196
193197 def custom_objects (self ):
194198 return {
@@ -197,6 +201,26 @@ def custom_objects(self):
197201 }
198202
199203
204+ class Conv2DReshapeBatchNormQuantize (Conv2DBatchNormQuantize ):
205+ """Ensure FQ does not get placed between Conv, Reshape and BatchNorm."""
206+
207+ def pattern (self ):
208+ return LayerPattern (
209+ 'BatchNormalization' ,
210+ inputs = [LayerPattern (
211+ 'Lambda' , config = {'name' : 'sepconv1d_squeeze.*' },
212+ inputs = [LayerPattern (
213+ 'Conv2D|DepthwiseConv2D' ,
214+ config = {'activation' : 'linear' })])])
215+
216+ def replacement (self , match_layer ):
217+ bn_layer_node = match_layer
218+ reshape_layer_node = bn_layer_node .input_layers [0 ]
219+ conv_layer_node = reshape_layer_node .input_layers [0 ]
220+
221+ return self ._replace (bn_layer_node , conv_layer_node )
222+
223+
200224class Conv2DBatchNormReLUQuantize (Conv2DBatchNormQuantize ):
201225 """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
202226
@@ -206,27 +230,24 @@ def pattern(self):
206230 'ReLU' ,
207231 inputs = [super (Conv2DBatchNormReLUQuantize , self ).pattern ()])
208232
209- def replacement (self , match_layer ):
210- relu_layer_node = match_layer
211- bn_layer_node = relu_layer_node .input_layers [0 ]
212- conv_layer_node = bn_layer_node .input_layers [0 ]
213-
233+ def _replace (self , relu_layer_node , bn_layer_node , conv_layer_node ):
214234 if _has_custom_quantize_config (
215235 relu_layer_node , bn_layer_node , conv_layer_node ):
216- return match_layer
236+ return relu_layer_node
217237
218238 conv_layer_node .layer ['config' ]['activation' ] = \
219239 keras .activations .serialize (quantize_aware_activation .NoOpActivation ())
220240 bn_layer_node .metadata ['quantize_config' ] = \
221241 default_8bit_quantize_configs .NoOpQuantizeConfig ()
222242
223- return match_layer
243+ return relu_layer_node
224244
225- def custom_objects (self ):
226- return {
227- 'NoOpQuantizeConfig' : default_8bit_quantize_configs .NoOpQuantizeConfig ,
228- 'NoOpActivation' : quantize_aware_activation .NoOpActivation
229- }
245+ def replacement (self , match_layer ):
246+ relu_layer_node = match_layer
247+ bn_layer_node = relu_layer_node .input_layers [0 ]
248+ conv_layer_node = bn_layer_node .input_layers [0 ]
249+
250+ return self ._replace (relu_layer_node , bn_layer_node , conv_layer_node )
230251
231252
232253class Conv2DBatchNormActivationQuantize (Conv2DBatchNormReLUQuantize ):
@@ -239,6 +260,34 @@ def pattern(self):
239260 inputs = [Conv2DBatchNormQuantize .pattern (self )])
240261
241262
263+ class Conv2DReshapeBatchNormReLUQuantize (Conv2DBatchNormReLUQuantize ):
264+ """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
265+
266+ def pattern (self ):
267+ return LayerPattern (
268+ 'ReLU' ,
269+ inputs = [Conv2DReshapeBatchNormQuantize .pattern (self )])
270+
271+ def replacement (self , match_layer ):
272+ relu_layer_node = match_layer
273+ bn_layer_node = relu_layer_node .input_layers [0 ]
274+ squeeze_layer_node = bn_layer_node .input_layers [0 ]
275+ conv_layer_node = squeeze_layer_node .input_layers [0 ]
276+
277+ return self ._replace (relu_layer_node , bn_layer_node , conv_layer_node )
278+
279+
280+ class Conv2DReshapeBatchNormActivationQuantize (
281+ Conv2DReshapeBatchNormReLUQuantize ):
282+ """Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
283+
284+ def pattern (self ):
285+ return LayerPattern (
286+ 'Activation' ,
287+ config = {'activation' : 'relu' },
288+ inputs = [Conv2DReshapeBatchNormQuantize .pattern (self )])
289+
290+
242291class SeparableConv1DQuantize (transforms .Transform ):
243292 """Add QAT support for Keras SeparableConv1D layer.
244293
0 commit comments