Commit eaeb6e7

nutsiepully authored and tensorflower-gardener committed
Add support for Relu folding in Conv/DConv/Dense
This makes ReLU6 support better, since users typically use 'relu' directly.

PiperOrigin-RevId: 364423101
1 parent b78818c commit eaeb6e7

File tree: 3 files changed, +36, -8 lines
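The commit message points at the user-facing motivation: ReLU is commonly attached as a standalone keras.layers.ReLU or Activation('relu') layer rather than as a layer's built-in activation, and the renamed transforms now fold that pattern when it follows Conv2D, DepthwiseConv2D, or Dense, not just Add. Below is a minimal sketch of the kind of model this affects, using the public tfmot entry point; the model itself is illustrative and not taken from this commit.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# A Conv2D followed by a standalone ReLU6 layer -- the pattern the updated
# LayerReLUQuantize transform now matches (previously only Add -> ReLU).
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(5, 2, input_shape=(3, 3, 3)),
    tf.keras.layers.ReLU(6.0),
])

# quantize_model runs the default 8-bit transforms, including the renamed
# LayerReLUQuantize / LayerReluActivationQuantize, so no fake-quant op is
# placed between the Conv2D and its trailing ReLU.
quantized_model = tfmot.quantization.keras.quantize_model(model)
quantized_model.summary()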

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 2 additions & 2 deletions

@@ -65,8 +65,8 @@ def apply(self, model, layer_quantize_map):
         default_8bit_transforms.ConcatTransform4Inputs(),
         default_8bit_transforms.ConcatTransform3Inputs(),
         default_8bit_transforms.ConcatTransform(),
-        default_8bit_transforms.AddReLUQuantize(),
-        default_8bit_transforms.AddActivationQuantize(),
+        default_8bit_transforms.LayerReLUQuantize(),
+        default_8bit_transforms.LayerReluActivationQuantize(),
     ]
     return model_transformer.ModelTransformer(
         model, transforms,

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 5 additions & 4 deletions

@@ -497,11 +497,12 @@ def replacement(self, match_layer):
         metadata=conv_metadata)
 
 
-class AddReLUQuantize(transforms.Transform):
+class LayerReLUQuantize(transforms.Transform):
   """Ensure FQ does not get placed between Add and ReLU."""
 
   def pattern(self):
-    return LayerPattern('ReLU', inputs=[LayerPattern('Add')])
+    return LayerPattern(
+        'ReLU', inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
 
   def replacement(self, match_layer):
     relu_layer_node = match_layer
@@ -518,14 +519,14 @@ def custom_objects(self):
     }
 
 
-class AddActivationQuantize(AddReLUQuantize):
+class LayerReluActivationQuantize(LayerReLUQuantize):
   """Ensure FQ does not get placed between Add and ReLU."""
 
   def pattern(self):
     return LayerPattern(
         'Activation',
         config={'activation': 'relu'},
-        inputs=[LayerPattern('Add')])
+        inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
 
 
 class InputLayerQuantize(transforms.Transform):
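The substantive change in this file is the widened input pattern: 'Add' becomes 'Add|Conv2D|DepthwiseConv2D|Dense', so the same transform now also matches a ReLU that directly follows a conv, depthwise conv, or dense layer. The '|' reads like regex-style alternation over layer class names; that is an assumption about the graph matcher, and the snippet below only illustrates the alternation semantics rather than the library's actual matching code.

import re

# Class-name alternation introduced by this commit.
PATTERN = 'Add|Conv2D|DepthwiseConv2D|Dense'

def matches_input_layer(layer_class_name):
  """Illustrative check: does a layer class name satisfy the alternation?"""
  # fullmatch mimics an anchored match; the real matcher's anchoring is an
  # assumption here.
  return re.fullmatch(PATTERN, layer_class_name) is not None

for name in ('Add', 'Conv2D', 'DepthwiseConv2D', 'Dense', 'Conv2DTranspose'):
  print(name, matches_input_layer(name))
# Only the four exact class names match; Conv2DTranspose does not.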

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 29 additions & 2 deletions

@@ -344,8 +344,8 @@ def testSeparableConvQuantize_(self, kwargs):
   # Conv2DReshapeBatchNormActivationQuantize
 
   @parameterized.parameters(
-      ('relu', default_8bit_transforms.AddReLUQuantize),
-      ('act_relu', default_8bit_transforms.AddActivationQuantize),
+      ('relu', default_8bit_transforms.LayerReLUQuantize),
+      ('act_relu', default_8bit_transforms.LayerReluActivationQuantize),
   )
   def testAddReLUQuantize(self, activation_type, transform_type):
     add = keras.layers.Add()
@@ -370,6 +370,33 @@ def testAddReLUQuantize(self, activation_type, transform_type):
         updated_metadata.get(add_layer.name).get('quantize_config'),
         default_8bit_quantize_configs.NoOpQuantizeConfig)
 
+  @parameterized.parameters(
+      ('relu', default_8bit_transforms.LayerReLUQuantize),
+      ('act_relu', default_8bit_transforms.LayerReluActivationQuantize))
+  def testLayerReLUQuantize(self, activation_type, transform_type):
+    # TODO(tfmot): Add tests for DepthConv and Dense
+    input_shape = (1, 3, 3, 3)
+    conv_layer = tf.keras.layers.Conv2D(5, 2, input_shape=input_shape)
+    if activation_type == 'relu':
+      act_layer = keras.layers.ReLU(6.0)
+    elif activation_type == 'act_relu':
+      act_layer = keras.layers.Activation('relu')
+
+    model = tf.keras.Sequential([conv_layer, act_layer])
+
+    transformed_model, updated_metadata = ModelTransformer(
+        model,
+        [transform_type()],
+    ).transform()
+
+    self.assertIsInstance(
+        updated_metadata.get(model.layers[0].name).get('quantize_config'),
+        default_8bit_quantize_configs.NoOpQuantizeConfig)
+
+    inputs = np.random.standard_normal(input_shape)
+    self.assertAllClose(
+        transformed_model.predict(inputs), model.predict(inputs))
+
   def testAddsQuantizeLayerAfterInputLayer(self):
     inp1 = keras.layers.Input((3,))
     inp2 = keras.layers.Input((3,))
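The new testLayerReLUQuantize only exercises Conv2D; the TODO leaves DepthwiseConv2D and Dense uncovered. A hedged sketch of how a Dense variant could mirror the same flow is shown below; the import paths are inferred from the files touched in this commit, and the helper is illustrative, not part of the change.

import numpy as np
import tensorflow as tf

# Import paths inferred from the files touched in this commit (assumption).
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_transforms
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer


def transform_dense_relu():
  """Dense -> ReLU variant of testLayerReLUQuantize (illustrative only)."""
  input_shape = (1, 4)
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(5, input_shape=input_shape[1:]),
      tf.keras.layers.ReLU(6.0),
  ])

  transformed_model, updated_metadata = model_transformer.ModelTransformer(
      model,
      [default_8bit_transforms.LayerReLUQuantize()],
  ).transform()

  # The Dense layer should carry a NoOpQuantizeConfig afterwards, i.e. no
  # fake-quant gets inserted between Dense and the trailing ReLU.
  config = updated_metadata.get(model.layers[0].name).get('quantize_config')
  assert isinstance(config, default_8bit_quantize_configs.NoOpQuantizeConfig)

  # The transform should not change numerics.
  inputs = np.random.standard_normal(input_shape)
  np.testing.assert_allclose(
      transformed_model.predict(inputs), model.predict(inputs), rtol=1e-5)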
