@@ -228,6 +228,105 @@ def pattern(self):
        inputs=[Conv2DBatchNormQuantize.pattern(self)])


+class SeparableConvQuantize(transforms.Transform):
+  """Break SeparableConv into a DepthwiseConv and Conv layer.
+
+  SeparableConv is a composition of a DepthwiseConv and a Conv layer. For the
+  purpose of quantization, a FQ operation needs to be placed between the output
+  of DepthwiseConv and the following Conv.
+
+  This is needed since there is a dynamic tensor in between the two layers, and
+  its range information needs to be captured by the FakeQuant op to ensure
+  full int8 quantization of the layers is possible.
+
+  Splitting the layer into 2 ensures that each individual layer is handled
+  correctly with respect to quantization.
+  """
+
+  def pattern(self):
+    return LayerPattern('SeparableConv2D')
+
+  @staticmethod
+  def _get_quantize_config(layer_node):
+    return layer_node.metadata.get('quantize_config')
+
+  def _has_custom_quantize_config(self, *layer_nodes):
+    for layer_node in layer_nodes:
+      if self._get_quantize_config(layer_node) is not None:
+        return True
+    return False
+
+  def replacement(self, match_layer):
+    if self._has_custom_quantize_config(match_layer):
+      return match_layer
+
+    sepconv_layer = match_layer.layer
+    sepconv_weights = list(match_layer.weights.values())
+
+    # TODO(pulkitb): SeparableConv has kwargs other than constructor args which
+    # need to be handled.
+    # Applicable to both layers: trainable, dtype, name
+    # Applicable to dconv: input_dim, input_shape, batch_input_shape, batch_size
+    # Needs special handling: weights
+    # Unknown: dynamic, autocast
+
+    dconv_layer = tf.keras.layers.DepthwiseConv2D(
+        kernel_size=sepconv_layer['config']['kernel_size'],
+        strides=sepconv_layer['config']['strides'],
+        padding=sepconv_layer['config']['padding'],
+        depth_multiplier=sepconv_layer['config']['depth_multiplier'],
+        data_format=sepconv_layer['config']['data_format'],
+        dilation_rate=sepconv_layer['config']['dilation_rate'],
+        activation=None,
+        use_bias=False,
+        depthwise_initializer=sepconv_layer['config']['depthwise_initializer'],
+        depthwise_regularizer=sepconv_layer['config']['depthwise_regularizer'],
+        depthwise_constraint=sepconv_layer['config']['depthwise_constraint'],
+        trainable=sepconv_layer['config']['trainable']
+    )
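+    # SeparableConv weight order: [depthwise_kernel, pointwise_kernel, bias].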
+    dconv_weights = collections.OrderedDict()
+    dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
+    dconv_layer_config = keras.layers.serialize(dconv_layer)
+    dconv_layer_config['name'] = dconv_layer.name
+    # Needed to ensure these new layers are considered for quantization.
+    dconv_metadata = {'quantize_config': None}
+
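+    # The 1x1 pointwise conv inherits the activation, bias, and pointwise
+    # initializer/regularizer/constraint settings from the SeparableConv.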
+    conv_layer = tf.keras.layers.Conv2D(
+        filters=sepconv_layer['config']['filters'],
+        kernel_size=(1, 1),  # (1,) * rank
+        strides=(1, 1),
+        padding='valid',
+        data_format=sepconv_layer['config']['data_format'],
+        dilation_rate=sepconv_layer['config']['dilation_rate'],
+        groups=1,
+        activation=sepconv_layer['config']['activation'],
+        use_bias=sepconv_layer['config']['use_bias'],
+        kernel_initializer=sepconv_layer['config']['pointwise_initializer'],
+        bias_initializer=sepconv_layer['config']['bias_initializer'],
+        kernel_regularizer=sepconv_layer['config']['pointwise_regularizer'],
+        bias_regularizer=sepconv_layer['config']['bias_regularizer'],
+        activity_regularizer=sepconv_layer['config']['activity_regularizer'],
+        kernel_constraint=sepconv_layer['config']['pointwise_constraint'],
+        bias_constraint=sepconv_layer['config']['bias_constraint'],
+        trainable=sepconv_layer['config']['trainable']
+    )
+    conv_weights = collections.OrderedDict()
+    conv_weights['kernel:0'] = sepconv_weights[1]
+    conv_weights['bias:0'] = sepconv_weights[2]
+    conv_layer_config = keras.layers.serialize(conv_layer)
+    conv_layer_config['name'] = conv_layer.name
+    # Needed to ensure these new layers are considered for quantization.
+    conv_metadata = {'quantize_config': None}
+
+    dconv_layer_node = LayerNode(
+        dconv_layer_config, weights=dconv_weights, metadata=dconv_metadata)
+    return LayerNode(
+        conv_layer_config,
+        weights=conv_weights,
+        input_layers=[dconv_layer_node],
+        metadata=conv_metadata)
+
+
class AddReLUQuantize(transforms.Transform):
  """Ensure FQ does not get placed between Add and ReLU."""
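In normal use this transform is not invoked directly; it runs as part of the transform set applied when a model is prepared for quantization-aware training. A hedged sketch of that entry point, assuming SeparableConvQuantize is registered in the default 8-bit transform set that tfmot.quantization.keras.quantize_model applies:

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    model = tf.keras.Sequential([
        tf.keras.layers.SeparableConv2D(
            8, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])

    # quantize_model rewrites the model graph with the registered transforms
    # (which would split the SeparableConv2D here) before inserting FakeQuant
    # ops for quantization-aware training.
    q_model = tfmot.quantization.keras.quantize_model(model)
    q_model.summary()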