Skip to content

Commit b5946e9

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Include SepConv layer transforms and tests
Full support for SepConv layers 1D and 2D. Also adds numerical tests to ensure validity. PiperOrigin-RevId: 324901716
1 parent d33ef45 commit b5946e9

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,13 @@ def apply(self, model, layer_quantize_map):
5252

5353
transforms = [
5454
default_8bit_transforms.InputLayerQuantize(),
55+
default_8bit_transforms.SeparableConv1DQuantize(),
5556
default_8bit_transforms.SeparableConvQuantize(),
57+
default_8bit_transforms.Conv2DReshapeBatchNormReLUQuantize(),
58+
default_8bit_transforms.Conv2DReshapeBatchNormActivationQuantize(),
5659
default_8bit_transforms.Conv2DBatchNormReLUQuantize(),
5760
default_8bit_transforms.Conv2DBatchNormActivationQuantize(),
61+
default_8bit_transforms.Conv2DReshapeBatchNormQuantize(),
5862
default_8bit_transforms.Conv2DBatchNormQuantize(),
5963
default_8bit_transforms.ConcatTransform6Inputs(),
6064
default_8bit_transforms.ConcatTransform5Inputs(),

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,39 @@ def _get_separable_conv2d_model(self):
117117
x = tf.keras.layers.ReLU()(x)
118118
return tf.keras.Model(i, x)
119119

120+
def _get_sepconv1d_bn_relu_model(self):
121+
i = tf.keras.Input(shape=(8, 3))
122+
x = tf.keras.layers.SeparableConv1D(
123+
filters=5, kernel_size=3, strides=2)(i)
124+
x = tf.keras.layers.BatchNormalization()(x)
125+
x = tf.keras.layers.ReLU()(x)
126+
return tf.keras.Model(i, x)
127+
128+
def _get_sepconv1d_bn_model(self):
129+
i = tf.keras.Input(shape=(8, 3))
130+
x = tf.keras.layers.SeparableConv1D(
131+
filters=5, kernel_size=3, strides=2)(i)
132+
x = tf.keras.layers.BatchNormalization()(x)
133+
return tf.keras.Model(i, x)
134+
135+
def _get_sepconv1d_stacked_model(self):
136+
i = tf.keras.Input(shape=(8, 3))
137+
x = tf.keras.layers.SeparableConv1D(
138+
filters=5, kernel_size=3, strides=2)(i)
139+
x = tf.keras.layers.BatchNormalization()(x)
140+
x = tf.keras.layers.SeparableConv1D(
141+
filters=5, kernel_size=3, strides=2)(x)
142+
x = tf.keras.layers.BatchNormalization()(x)
143+
x = tf.keras.layers.ReLU()(x)
144+
return tf.keras.Model(i, x)
145+
120146
@parameterized.parameters([
121147
_get_single_conv_model, _get_single_dense_model,
122148
_get_single_conv_relu_model, _get_stacked_convs_model,
123149
_get_conv_bn_relu_model, _get_depthconv_bn_relu_model,
124-
_get_separable_conv2d_model
150+
_get_separable_conv2d_model,
151+
_get_sepconv1d_bn_model, _get_sepconv1d_bn_relu_model,
152+
_get_sepconv1d_stacked_model
125153
])
126154
def testModelEndToEnd(self, model_fn):
127155
# 1. Check whether quantized model graph can be constructed.

0 commit comments

Comments
 (0)