
Commit 21aac43

nutsiepully authored and tensorflower-gardener committed
Support for SeparableConv2D for QAT
SeparableConv2D is a combination of DepthwiseConv2D and Conv2D. QAT needs to break the layer up into this combination so that the rest of the infrastructure can then apply QAT to the resulting layers. It is not possible to support SeparableConv directly, since the layer does not provide hooks to alter its internal graph construction.

PiperOrigin-RevId: 321030603
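The decomposition this transform relies on can be checked directly in Keras. Below is a minimal hedged sketch (not part of this commit) showing that, once weights are copied across, a SeparableConv2D matches a DepthwiseConv2D followed by a 1x1 Conv2D; the shapes and layer configuration are illustrative only.

# Hedged sketch: SeparableConv2D == DepthwiseConv2D + 1x1 Conv2D with shared
# weights. Not taken from the commit; shapes here are arbitrary examples.
import numpy as np
import tensorflow as tf

x = np.random.rand(1, 8, 8, 3).astype('float32')

sepconv = tf.keras.layers.SeparableConv2D(filters=4, kernel_size=3)
y_sep = sepconv(x)  # calling the layer builds it and creates its weights

# SeparableConv2D stores [depthwise_kernel, pointwise_kernel, bias].
depthwise_kernel, pointwise_kernel, bias = sepconv.get_weights()

dconv = tf.keras.layers.DepthwiseConv2D(kernel_size=3, use_bias=False)
pconv = tf.keras.layers.Conv2D(filters=4, kernel_size=1)
_ = pconv(dconv(x))  # build both layers so their weights exist
dconv.set_weights([depthwise_kernel])
pconv.set_weights([pointwise_kernel, bias])

np.testing.assert_allclose(y_sep.numpy(), pconv(dconv(x)).numpy(), atol=1e-5)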
1 parent ee53c9a commit 21aac43

3 files changed: 157 additions & 1 deletion


tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 99 additions & 0 deletions
@@ -228,6 +228,105 @@ def pattern(self):
        inputs=[Conv2DBatchNormQuantize.pattern(self)])


class SeparableConvQuantize(transforms.Transform):
  """Break SeparableConv into a DepthwiseConv and Conv layer.

  SeparableConv is a composition of a DepthwiseConv and a Conv layer. For the
  purpose of quantization, a FQ operation needs to be placed between the output
  of DepthwiseConv and the following Conv.

  This is needed since there is a dynamic tensor in between the two layers, and
  its range information needs to be captured by the FakeQuant op to ensure
  full int8 quantization of the layers is possible.

  Splitting the layer into two ensures that each individual layer is handled
  correctly with respect to quantization.
  """

  def pattern(self):
    return LayerPattern('SeparableConv2D')

  @staticmethod
  def _get_quantize_config(layer_node):
    return layer_node.metadata.get('quantize_config')

  def _has_custom_quantize_config(self, *layer_nodes):
    for layer_node in layer_nodes:
      if self._get_quantize_config(layer_node) is not None:
        return True
    return False

  def replacement(self, match_layer):
    if self._has_custom_quantize_config(match_layer):
      return match_layer

    sepconv_layer = match_layer.layer
    sepconv_weights = list(match_layer.weights.values())

    # TODO(pulkitb): SeparableConv has kwargs other than constructor args which
    # need to be handled.
    # Applicable to both layers: trainable, dtype, name
    # Applicable to dconv: input_dim, input_shape, batch_input_shape, batch_size
    # Needs special handling: weights
    # Unknown: dynamic, autocast

    dconv_layer = tf.keras.layers.DepthwiseConv2D(
        kernel_size=sepconv_layer['config']['kernel_size'],
        strides=sepconv_layer['config']['strides'],
        padding=sepconv_layer['config']['padding'],
        depth_multiplier=sepconv_layer['config']['depth_multiplier'],
        data_format=sepconv_layer['config']['data_format'],
        dilation_rate=sepconv_layer['config']['dilation_rate'],
        activation=None,
        use_bias=False,
        depthwise_initializer=sepconv_layer['config']['depthwise_initializer'],
        depthwise_regularizer=sepconv_layer['config']['depthwise_regularizer'],
        depthwise_constraint=sepconv_layer['config']['depthwise_constraint'],
        trainable=sepconv_layer['config']['trainable']
    )
    dconv_weights = collections.OrderedDict()
    dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
    dconv_layer_config = keras.layers.serialize(dconv_layer)
    dconv_layer_config['name'] = dconv_layer.name
    # Needed to ensure these new layers are considered for quantization.
    dconv_metadata = {'quantize_config': None}

    conv_layer = tf.keras.layers.Conv2D(
        filters=sepconv_layer['config']['filters'],
        kernel_size=(1, 1),  # (1,) * rank
        strides=(1, 1),
        padding='valid',
        data_format=sepconv_layer['config']['data_format'],
        dilation_rate=sepconv_layer['config']['dilation_rate'],
        groups=1,
        activation=sepconv_layer['config']['activation'],
        use_bias=sepconv_layer['config']['use_bias'],
        kernel_initializer=sepconv_layer['config']['pointwise_initializer'],
        bias_initializer=sepconv_layer['config']['bias_initializer'],
        kernel_regularizer=sepconv_layer['config']['pointwise_regularizer'],
        bias_regularizer=sepconv_layer['config']['bias_regularizer'],
        activity_regularizer=sepconv_layer['config']['activity_regularizer'],
        kernel_constraint=sepconv_layer['config']['pointwise_constraint'],
        bias_constraint=sepconv_layer['config']['bias_constraint'],
        trainable=sepconv_layer['config']['trainable']
    )
    conv_weights = collections.OrderedDict()
    conv_weights['kernel:0'] = sepconv_weights[1]
    conv_weights['bias:0'] = sepconv_weights[2]
    conv_layer_config = keras.layers.serialize(conv_layer)
    conv_layer_config['name'] = conv_layer.name
    # Needed to ensure these new layers are considered for quantization.
    conv_metadata = {'quantize_config': None}

    dconv_layer_node = LayerNode(
        dconv_layer_config, weights=dconv_weights, metadata=dconv_metadata)
    return LayerNode(
        conv_layer_config,
        weights=conv_weights,
        input_layers=[dconv_layer_node],
        metadata=conv_metadata)


class AddReLUQuantize(transforms.Transform):
  """Ensure FQ does not get placed between Add and ReLU."""
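To make the docstring's point concrete, here is a hedged sketch (not part of the commit) of the FQ placement the split makes possible. The tf.quantization.fake_quant_with_min_max_args call stands in for the FakeQuant op the QAT infrastructure inserts, and the fixed min/max range is a placeholder for the range QAT actually learns during training.

# Hedged sketch: once SeparableConv2D is split, the intermediate tensor
# between the two layers can be fake-quantized. Constant ranges are stand-ins.
import tensorflow as tf

def split_sepconv_with_fq(x, dconv_layer, conv_layer):
  h = dconv_layer(x)
  # Capture the range of the dynamic tensor between the two layers; this is
  # the step that is impossible while both stages live inside a single
  # SeparableConv2D layer.
  h = tf.quantization.fake_quant_with_min_max_args(
      h, min=-6.0, max=6.0, num_bits=8)
  return conv_layer(h)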

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 57 additions & 0 deletions
@@ -221,6 +221,63 @@ def testConv2DBatchNormReLUQuantize(
    self.assertAllClose(
        transformed_model.predict(inputs), model.predict(inputs))

  @parameterized.named_parameters(
      ('padding_valid', {'padding': 'valid'}),
      ('padding_same', {'padding': 'same'}),
      ('padding_same_dilation_2', {'padding': 'same', 'dilation_rate': 2}),
      ('strides', {'strides': 2}),
      ('dilation_rate', {'dilation_rate': 2}),
      ('depth_multiplier', {'depth_multiplier': 2}),
      ('regularizer', {
          'depthwise_regularizer': 'l2',
          'pointwise_regularizer': 'l2',
          'bias_regularizer': 'l2',
          'activity_regularizer': 'l2'}),
      ('constraint', {
          'depthwise_constraint': tf.keras.constraints.max_norm(2.),
          'pointwise_constraint': tf.keras.constraints.min_max_norm(0., 2.),
          'bias_constraint': tf.keras.constraints.unit_norm()})
  )
  def testSeparableConvQuantize_(self, kwargs):
    kwargs['filters'] = 2
    kwargs['kernel_size'] = 3
    num_samples = 2
    stack_size = 3
    num_row = 7
    num_col = 6

    sepconv_model = tf.keras.Sequential([
        tf.keras.Input(
            shape=(num_row, num_col, stack_size), batch_size=num_samples),
        tf.keras.layers.SeparableConv2D(**kwargs)])

    transformed_model, updated_metadata = ModelTransformer(
        sepconv_model,
        [default_8bit_transforms.SeparableConvQuantize()],
    ).transform()

    self.assertContainsSubset(
        updated_metadata.keys(), {'depthwise_conv2d', 'conv2d'})
    # Transformed model should have the same output shape.
    self.assertEqual(sepconv_model.output_shape, transformed_model.output_shape)

    x = np.random.rand(*sepconv_model.input_shape)
    y = np.random.rand(*sepconv_model.output_shape)

    # Ensure the models are equivalent, and forward pass results are the same.
    self.assertAllClose(sepconv_model.predict(x), transformed_model.predict(x))

    # Ensure the models are equivalent, and training results are the same.
    sepconv_model.compile(loss='categorical_crossentropy', optimizer='sgd')
    sepconv_model.fit(x, y, epochs=100)
    transformed_model.compile(loss='categorical_crossentropy', optimizer='sgd')
    transformed_model.fit(x, y, epochs=100)

    # Over a long training cycle with constraints and regularizers, the models
    # can accumulate very minute differences. Hence reducing tol to 1e-5.
    self.assertAllClose(sepconv_model.predict(x), transformed_model.predict(x),
                        atol=1e-5, rtol=1e-5)

  @parameterized.parameters(
      ('relu', default_8bit_transforms.AddReLUQuantize),
      ('act_relu', default_8bit_transforms.AddActivationQuantize),

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/transforms.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class LayerPattern(object):
     pattern = LayerPattern('Concat', {}, [
         LayerPattern('Conv2D', {}, []),
         LayerPattern('Conv2D', {}, [])
-    )
+    ])
   """
 
   def __init__(self, class_name, config=None, inputs=None):
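For reference, a short hedged sketch of the corrected pattern syntax in use; the single-layer form is the one SeparableConvQuantize above relies on, since config and inputs default to None in the constructor shown here.

# Hedged usage sketch, mirroring the corrected docstring example: a Concat
# fed by two Conv2D layers, and a bare single-layer pattern.
concat_pattern = LayerPattern('Concat', {}, [
    LayerPattern('Conv2D', {}, []),
    LayerPattern('Conv2D', {}, [])
])
sepconv_pattern = LayerPattern('SeparableConv2D')  # config/inputs default to None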
