Skip to content

Commit 1fbc2ed

Browse files
Xhark authored and tensorflower-gardener committed
Add UpSampling2D support.
There are two interpolation modes for the UpSampling2D op (nearest and bilinear). ResizeNearest just passed the numerical test, but ResizeBilinear has a larger quantization error for now. PiperOrigin-RevId: 340558630
1 parent a5dfe1a commit 1fbc2ed

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,21 @@ class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
9797
_no_quantize(layers.Cropping2D),
9898
_no_quantize(layers.Cropping3D),
9999
# _no_quantize(layers.UpSampling1D),
100-
# _no_quantize(layers.UpSampling2D),
100+
101+
# TODO(tfmot): Reduce the quantization errors for bilinear interpolation
102+
# type for UpSampling2D op. UpSampling2D supports two interpolation types,
103+
# nearest and bilinear. We convert the op to the ResizeBilinear integer op on
104+
# TFLite. This ResizeBilinear TFLite op requires that the input and output
105+
# have the same quantization parameters (scale and zero_point). To do that, the
106+
# TFLite converter inserts quantization cast op right after the input to
107+
# match the quantization params of the output. Current QAT doesn't consider
108+
# this behavior yet, so now we have larger quantization errors than we
109+
# expected. We have to add support for it on QAT or change the TFLite
110+
# kernel op to support different quantization params for input and output.
111+
# (Note that the nearest case just copies values, so there are no additional
112+
# errors even if the quantization order is different.)
113+
_QuantizeInfo(layers.UpSampling2D, [], [], True),
114+
101115
# _no_quantize(layers.UpSampling3D),
102116
_no_quantize(layers.ZeroPadding1D),
103117
_no_quantize(layers.ZeroPadding2D),

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,28 @@ def _get_sepconv1d_stacked_model(self):
143143
x = tf.keras.layers.ReLU()(x)
144144
return tf.keras.Model(i, x)
145145

146+
def _get_upsampling2d_nearest_model(self):
147+
i = tf.keras.Input(shape=(32, 32, 3))
148+
x = tf.keras.layers.UpSampling2D(size=(3, 4), interpolation='nearest')(i)
149+
return tf.keras.Model(i, x)
150+
151+
def _get_upsampling2d_bilinear_model(self):
152+
i = tf.keras.Input(shape=(1, 3, 1))
153+
x = tf.keras.layers.UpSampling2D(size=(1, 5), interpolation='bilinear')(i)
154+
return tf.keras.Model(i, x)
155+
146156
@parameterized.parameters([
147157
_get_single_conv_model, _get_single_dense_model,
148158
_get_single_conv_relu_model, _get_stacked_convs_model,
149159
_get_conv_bn_relu_model, _get_depthconv_bn_relu_model,
150160
_get_separable_conv2d_model,
151161
_get_sepconv1d_bn_model, _get_sepconv1d_bn_relu_model,
152-
_get_sepconv1d_stacked_model
162+
_get_sepconv1d_stacked_model,
163+
_get_upsampling2d_nearest_model,
164+
# _get_upsampling2d_bilinear_model
165+
# TODO(tfmot): There are gaps between ResizeBilinear with FakeQuant and
166+
# TFLite quantized ResizeBilinear op. It has a bit more quantization
167+
# error than other ops in this test now.
153168
])
154169
def testModelEndToEnd(self, model_fn):
155170
# 1. Check whether quantized model graph can be constructed.

tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ class QuantizeFullIntegerModelTest(tf.test.TestCase, parameterized.TestCase):
141141
}),
142142
(layers.UpSampling2D, {
143143
'input_shape': (4, 6, 1),
144+
'interpolation': 'nearest',
145+
}),
146+
(layers.UpSampling2D, {
147+
'input_shape': (4, 6, 1),
148+
'interpolation': 'bilinear',
144149
}),
145150
(layers.UpSampling3D, {
146151
'input_shape': (4, 6, 1),
@@ -266,7 +271,6 @@ class QuantizeFullIntegerModelTest(tf.test.TestCase, parameterized.TestCase):
266271
@parameterized.parameters([
267272
l for l in _LAYER_PARAMS if l[0] not in [
268273
# Not done since TFLite converter doesn't support in TF2 yet.
269-
layers.UpSampling2D,
270274
layers.Conv3D,
271275
layers.Conv3DTranspose,
272276
layers.AveragePooling3D,

0 commit comments

Comments
 (0)