Commit 3fec5a7

Add transforms to ensure no fake quantization between Add and ReLU.

Without the transform, the quantized model converts to TensorFlow Lite as:

  Inputs          Outputs  builtin_options                        opcode
  [169, 170, 37]  [171]    {'fused_activation_function': 1, ...}  CONV_2D
  [171, 172, 70]  [173]    {...}                                  DEPTHWISE_CONV_2D
  [173, 174, 38]  [175]    {...}                                  CONV_2D
  [169, 175]      [176]    {'fused_activation_function': 0}       ADD
  [176]           [177]    None                                   RELU
  [177]           [178]    None                                   QUANTIZE

With the transform, RELU is fused into ADD:

  Inputs          Outputs  builtin_options                        opcode
  [136, 137, 33]  [138]    {'fused_activation_function': 1, ...}  CONV_2D
  [138, 139, 68]  [140]    {...}                                  DEPTHWISE_CONV_2D
  [140, 141, 34]  [142]    {...}                                  CONV_2D
  [136, 142]      [143]    {'fused_activation_function': 1}       ADD

PiperOrigin-RevId: 315794544
1 parent be4ea2d commit 3fec5a7
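For context, the fusion above can be reproduced with the standard tfmot quantization-aware-training flow followed by TFLite conversion; the new transforms run as part of the default 8-bit layout transform inside quantize_model. The sketch below is illustrative only: the model architecture, layer choices, and shapes are assumptions, not the model used to produce the tables above.

# Minimal sketch (not part of this commit): observe the Add+ReLU fusion by
# converting a quantization-aware model to TensorFlow Lite. The architecture
# and shapes here are illustrative assumptions.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

inp = tf.keras.Input((32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, padding='same')(inp)
y = tf.keras.layers.Conv2D(8, 3, padding='same')(x)
z = tf.keras.layers.ReLU()(tf.keras.layers.Add()([x, y]))
model = tf.keras.Model(inp, z)

# quantize_model applies the default 8-bit layout transforms (including the
# new Add+ReLU ones) before wrapping layers with fake-quant ops.
qat_model = tfmot.quantization.keras.quantize_model(model)

# Convert to TFLite; with the transform, RELU is fused into ADD
# (fused_activation_function=1) instead of appearing as a separate op.
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()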

3 files changed: +60 -1 lines changed


tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 2 additions & 1 deletion
@@ -60,8 +60,9 @@ def apply(self, model, layer_quantize_map):
         default_8bit_transforms.ConcatTransform4Inputs(),
         default_8bit_transforms.ConcatTransform3Inputs(),
         default_8bit_transforms.ConcatTransform(),
+        default_8bit_transforms.AddReLUQuantize(),
+        default_8bit_transforms.AddActivationQuantize(),
     ]
-
     return model_transformer.ModelTransformer(
         model, transforms,
         layer_quantize_map.keys(), layer_quantize_map).transform()

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 31 additions & 0 deletions
@@ -228,6 +228,37 @@ def pattern(self):
         inputs=[Conv2DBatchNormQuantize.pattern(self)])


+class AddReLUQuantize(transforms.Transform):
+  """Ensure FQ does not get placed between Add and ReLU."""
+
+  def pattern(self):
+    return LayerPattern('ReLU', inputs=[LayerPattern('Add')])
+
+  def replacement(self, match_layer):
+    relu_layer_node = match_layer
+    add_layer_node = relu_layer_node.input_layers[0]
+
+    add_layer_node.metadata['quantize_config'] = \
+        default_8bit_quantize_configs.NoOpQuantizeConfig()
+
+    return match_layer
+
+  def custom_objects(self):
+    return {
+        'NoOpQuantizeConfig': default_8bit_quantize_configs.NoOpQuantizeConfig,
+    }
+
+
+class AddActivationQuantize(AddReLUQuantize):
+  """Ensure FQ does not get placed between Add and ReLU."""
+
+  def pattern(self):
+    return LayerPattern(
+        'Activation',
+        config={'activation': 'relu'},
+        inputs=[LayerPattern('Add')])
+
+
 class InputLayerQuantize(transforms.Transform):
   """Quantizes InputLayer, by adding QuantizeLayer after it.

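The key step in AddReLUQuantize.replacement is assigning NoOpQuantizeConfig to the matched Add layer's metadata: with no output quantizer on Add, no FakeQuant node is emitted between Add and the following ReLU, so the TFLite converter can fuse them. As a rough sketch only, under the public tfmot QuantizeConfig interface a no-op config looks approximately like this (the real default_8bit_quantize_configs.NoOpQuantizeConfig may differ in detail):

# Rough sketch of a no-op quantize config, based on the documented
# tfmot QuantizeConfig interface; not the library's exact source.
import tensorflow_model_optimization as tfmot


class NoOpQuantizeConfigSketch(tfmot.quantization.keras.QuantizeConfig):
  """Quantizes nothing: no weight, activation, or output quantizers."""

  def get_weights_and_quantizers(self, layer):
    return []

  def get_activations_and_quantizers(self, layer):
    return []

  def set_quantize_weights(self, layer, quantize_weights):
    pass

  def set_quantize_activations(self, layer, quantize_activations):
    pass

  def get_output_quantizers(self, layer):
    # Returning no output quantizers is what keeps a FakeQuant node from
    # being placed after Add, so the converter can fuse the following ReLU.
    return []

  def get_config(self):
    return {}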

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 27 additions & 0 deletions
@@ -221,6 +221,33 @@ def testConv2DBatchNormReLUQuantize(
     self.assertAllClose(
         transformed_model.predict(inputs), model.predict(inputs))

+  @parameterized.parameters(
+      ('relu', default_8bit_transforms.AddReLUQuantize),
+      ('act_relu', default_8bit_transforms.AddActivationQuantize),
+  )
+  def testAddReLUQuantize(self, activation_type, transform_type):
+    add = keras.layers.Add()
+    if activation_type == 'relu':
+      activation = keras.layers.ReLU(6.0)
+    elif activation_type == 'act_relu':
+      activation = keras.layers.Activation('relu')
+
+    inp1 = keras.layers.Input((3,))
+    inp2 = keras.layers.Input((3,))
+    x = activation(add([inp1, inp2]))
+    model = keras.Model([inp1, inp2], x)
+
+    transformed_model, updated_metadata = ModelTransformer(
+        model,
+        [transform_type()],
+    ).transform()
+
+    add_layer = transformed_model.layers[2]
+
+    self.assertIsInstance(
+        updated_metadata.get(add_layer.name).get('quantize_config'),
+        default_8bit_quantize_configs.NoOpQuantizeConfig)
+
   def testAddsQuantizeLayerAfterInputLayer(self):
     inp1 = keras.layers.Input((3,))
     inp2 = keras.layers.Input((3,))
