Skip to content

Commit 939bed8

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Include SeparableConv2dQuantize in QAT
PiperOrigin-RevId: 321031264
1 parent b47e0d9 commit 939bed8

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def apply(self, model, layer_quantize_map):
5252

5353
transforms = [
5454
default_8bit_transforms.InputLayerQuantize(),
55+
default_8bit_transforms.SeparableConvQuantize(),
5556
default_8bit_transforms.Conv2DBatchNormReLUQuantize(),
5657
default_8bit_transforms.Conv2DBatchNormActivationQuantize(),
5758
default_8bit_transforms.Conv2DBatchNormQuantize(),

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,20 @@ def _get_depthconv_bn_relu_model(self):
109109
x = tf.keras.layers.ReLU()(x)
110110
return tf.keras.Model(i, x)
111111

112-
@parameterized.parameters(
112+
def _get_separable_conv2d_model(self):
113+
i = tf.keras.Input(shape=(12, 12, 3))
114+
x = tf.keras.layers.SeparableConv2D(
115+
filters=5, kernel_size=(3, 3), strides=(2, 2))(i)
116+
x = tf.keras.layers.BatchNormalization()(x)
117+
x = tf.keras.layers.ReLU()(x)
118+
return tf.keras.Model(i, x)
119+
120+
@parameterized.parameters([
113121
_get_single_conv_model, _get_single_dense_model,
114122
_get_single_conv_relu_model, _get_stacked_convs_model,
115-
_get_conv_bn_relu_model, _get_depthconv_bn_relu_model)
123+
_get_conv_bn_relu_model, _get_depthconv_bn_relu_model,
124+
_get_separable_conv2d_model
125+
])
116126
def testModelEndToEnd(self, model_fn):
117127
# 1. Check whether quantized model graph can be constructed.
118128
model = model_fn(self)
@@ -121,7 +131,7 @@ def testModelEndToEnd(self, model_fn):
121131
# 2. Sanity check to ensure basic training on random data works.
122132
x_train, y_train = self._create_test_data(model)
123133
model.compile(loss='mse', optimizer='sgd', metrics=['accuracy'])
124-
model.fit(x_train, y_train, epochs=10)
134+
model.fit(x_train, y_train, epochs=100)
125135

126136
x_test, y_test = self._create_test_data(model)
127137

0 commit comments

Comments
 (0)