Skip to content

Commit ee53c9a

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Test QAT results numerically to ensure correctness
Adds a test to verify the QAT forward pass values for various models/layers to ensure they match TFLite results. PiperOrigin-RevId: 320452756
1 parent fe5eb7d commit ee53c9a

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,17 @@ py_library(
136136
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations:model_transformer",
137137
],
138138
)
139+
140+
py_test(
141+
name = "quantize_numerical_test",
142+
srcs = ["quantize_numerical_test.py"],
143+
python_version = "PY3",
144+
deps = [
145+
# absl/testing:parameterized dep1,
146+
# numpy dep1,
147+
# tensorflow dep1,
148+
"//tensorflow_model_optimization/python/core/keras:test_utils",
149+
"//tensorflow_model_optimization/python/core/quantization/keras:quantize",
150+
"//tensorflow_model_optimization/python/core/quantization/keras:utils",
151+
],
152+
)
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Numerical verification tests for QAT."""
16+
17+
18+
import tempfile
19+
20+
from absl.testing import parameterized
21+
22+
import numpy as np
23+
import tensorflow as tf
24+
25+
from tensorflow.python.keras import keras_parameterized
26+
from tensorflow_model_optimization.python.core.quantization.keras import quantize
27+
from tensorflow_model_optimization.python.core.quantization.keras import utils
28+
29+
30+
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
31+
class QuantizeNumericalTest(tf.test.TestCase, parameterized.TestCase):
32+
33+
@staticmethod
34+
def _batch(dims, batch_size):
35+
if dims[0] is None:
36+
dims[0] = batch_size
37+
return dims
38+
39+
def _create_test_data(self, model):
40+
x = np.random.randn(
41+
*self._batch(model.input.get_shape().as_list(), 1)).astype('float32')
42+
y = np.random.randn(
43+
*self._batch(model.output.get_shape().as_list(), 1)).astype('float32')
44+
45+
return x, y
46+
47+
@staticmethod
48+
def _execute_tflite(tflite_file, x_test, y_test):
49+
interpreter = tf.lite.Interpreter(model_path=tflite_file)
50+
interpreter.allocate_tensors()
51+
52+
input_index = interpreter.get_input_details()[0]['index']
53+
output_index = interpreter.get_output_details()[0]['index']
54+
55+
for x, _ in zip(x_test, y_test):
56+
x = x.reshape((1,) + x.shape)
57+
interpreter.set_tensor(input_index, x)
58+
interpreter.invoke()
59+
y_ = interpreter.get_tensor(output_index)
60+
61+
return y_
62+
63+
def _get_single_conv_model(self):
64+
i = tf.keras.Input(shape=(32, 32, 3))
65+
x = tf.keras.layers.Conv2D(2, kernel_size=(3, 3), strides=(2, 2))(i)
66+
return tf.keras.Model(i, x)
67+
68+
def _get_single_dense_model(self):
69+
i = tf.keras.Input(shape=(5,))
70+
x = tf.keras.layers.Dense(3)(i)
71+
return tf.keras.Model(i, x)
72+
73+
def _get_single_conv_relu_model(self):
74+
i = tf.keras.Input(shape=(6, 6, 3))
75+
x = tf.keras.layers.Conv2D(
76+
2, kernel_size=(3, 3), strides=(2, 2), activation='relu')(i)
77+
x = tf.keras.layers.ReLU()(x)
78+
return tf.keras.Model(i, x)
79+
80+
def _get_stacked_convs_model(self):
81+
i = tf.keras.Input(shape=(64, 64, 3))
82+
x = tf.keras.layers.Conv2D(
83+
10, kernel_size=(3, 3), strides=(1, 1), activation='relu')(i)
84+
x = tf.keras.layers.Conv2D(
85+
# Setting strides to (1, 1) passes test, (2, 2) fails test?
86+
# Somehow one value is at border.
87+
# Train over 100 epochs, and issue goes away.
88+
# Why are all the first values zero?
89+
10, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
90+
x = tf.keras.layers.Conv2D(
91+
10, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
92+
x = tf.keras.layers.Conv2D(
93+
5, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
94+
x = tf.keras.layers.Conv2D(
95+
2, kernel_size=(3, 3), strides=(2, 2), activation='relu')(x)
96+
return tf.keras.Model(i, x)
97+
98+
def _get_conv_bn_relu_model(self):
99+
i = tf.keras.Input(shape=(6, 6, 3))
100+
x = tf.keras.layers.Conv2D(3, kernel_size=(3, 3), strides=(2, 2))(i)
101+
x = tf.keras.layers.BatchNormalization()(x)
102+
x = tf.keras.layers.ReLU()(x)
103+
return tf.keras.Model(i, x)
104+
105+
def _get_depthconv_bn_relu_model(self):
106+
i = tf.keras.Input(shape=(6, 6, 3))
107+
x = tf.keras.layers.DepthwiseConv2D(kernel_size=(3, 3), strides=(2, 2))(i)
108+
x = tf.keras.layers.BatchNormalization()(x)
109+
x = tf.keras.layers.ReLU()(x)
110+
return tf.keras.Model(i, x)
111+
112+
@parameterized.parameters(
113+
_get_single_conv_model, _get_single_dense_model,
114+
_get_single_conv_relu_model, _get_stacked_convs_model,
115+
_get_conv_bn_relu_model, _get_depthconv_bn_relu_model)
116+
def testModelEndToEnd(self, model_fn):
117+
# 1. Check whether quantized model graph can be constructed.
118+
model = model_fn(self)
119+
model = quantize.quantize_model(model)
120+
121+
# 2. Sanity check to ensure basic training on random data works.
122+
x_train, y_train = self._create_test_data(model)
123+
model.compile(loss='mse', optimizer='sgd', metrics=['accuracy'])
124+
model.fit(x_train, y_train, epochs=10)
125+
126+
x_test, y_test = self._create_test_data(model)
127+
128+
y_tf = model.predict(x_test)
129+
130+
# 3. Ensure conversion to TFLite works.
131+
_, tflite_file = tempfile.mkstemp('.tflite')
132+
print('TFLite File: ', tflite_file)
133+
with quantize.quantize_scope():
134+
utils.convert_keras_to_tflite(model, tflite_file)
135+
136+
# 4. Verify input runs on converted model.
137+
y_tfl = self._execute_tflite(tflite_file, x_test, y_test)
138+
139+
# 5. Verify results are the same in TF and TFL.
140+
self.assertAllClose(y_tf, y_tfl)
141+
142+
143+
if __name__ == '__main__':
144+
tf.test.main()

0 commit comments

Comments
 (0)