
Commit d384442

alanchiao authored and tensorflower-gardener committed

Add initial layer tests to validate conversion for full-integer quantization.

There are three categories of layers that prevent a supported and recommended path to deployment:
1) Layers for which we need to place FakeQuants properly.
2) Layers that we need to make per-axis according to the quantization scheme.
3) Layers that are not supported in TFLite (whether from lack of float or quantized support).

PiperOrigin-RevId: 305133601
1 parent b0ab25f commit d384442
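
For orientation, the deployment path these layer tests are meant to protect looks roughly like the sketch below. It is not part of this commit: it uses the public tfmot API rather than the internal test helpers, and the layer and input shape are illustrative only.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Illustrative single-layer model (layer and shape chosen arbitrarily).
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, (2, 2), input_shape=(4, 6, 1)),
])

# Wrap the model for quantization-aware training via the public API.
quantized_model = tfmot.quantization.keras.quantize_model(model)
quantized_model.compile(optimizer='adam', loss='mse')
# ... train quantized_model as usual ...

# Convert to TFLite; a fully integer model additionally depends on
# converter/runtime support for each layer, which is what the new tests track.
converter = tf.lite.TFLiteConverter.from_keras_model(quantized_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()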

File tree

2 files changed: 332 additions, 3 deletions


tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py

Lines changed: 324 additions & 1 deletion
@@ -21,7 +21,7 @@
 import tempfile
 
 from absl.testing import parameterized
-
+import numpy as np
 import tensorflow as tf
 
 # TODO(b/139939526): move to public API.
@@ -31,6 +31,8 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize
 from tensorflow_model_optimization.python.core.quantization.keras import utils as test_utils
 
+layers = tf.keras.layers
+
 
 @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
 class QuantizeFunctionalTest(tf.test.TestCase, parameterized.TestCase):
@@ -80,5 +82,326 @@ def testQuantizesMnist(self):
         rtol=0.2, atol=0.2)
 
 
+# Set of tests to determine what we can include in the whitelisted layers
+# for the default API.
+#
+# TFLite in TF 2.X currently does not support creation of full-integer models.
+# However, having every layer pass these tests ensures that the resulting
+# quantization-aware trained model will have a path to deployment once
+# TFLite adds support.
+#
+# Note these tests are not perfect yet.
+# 1. Some Keras layers use different
+#    TensorFlow ops depending on the initialization parameters. This
+#    tests the most noticable ones, but unlikely all.
+#
+# TODO(tfmot): merge with test class above when run_all_keras_modes works
+# with V1.
+class QuantizeFullIntegerModelTest(tf.test.TestCase, parameterized.TestCase):
+
+  _LAYER_PARAMS = [
+      (layers.ReLU, {}),
+      (layers.Softmax, {}),
+      (layers.Conv1D, {
+          'input_shape': (3, 6),
+          'filters': 4,
+          'kernel_size': 2,
+      }),
+      (layers.Conv2D, {
+          'input_shape': (4, 6, 1),
+          'filters': 4,
+          'kernel_size': (2, 2)
+      }),
+      (layers.Conv3D, {
+          'input_shape': (3, 4, 6, 1),
+          'filters': 4,
+          'kernel_size': (2, 2, 2)
+      }),
+      (layers.Conv2DTranspose, {
+          'input_shape': (4, 6, 1),
+          'filters': 4,
+          'kernel_size': (2, 2)
+      }),
+      (layers.Conv3DTranspose, {
+          'input_shape': (3, 4, 6, 1),
+          'filters': 4,
+          'kernel_size': (2, 2, 2)
+      }),
+      (layers.Cropping1D, {
+          'input_shape': (3, 6),
+      }),
+      (layers.Cropping2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.Cropping3D, {
+          'input_shape': (3, 4, 6, 1),
+      }),
+      (layers.UpSampling1D, {
+          'input_shape': (3, 6)
+      }),
+      (layers.UpSampling2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.UpSampling3D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.ZeroPadding1D, {
+          'input_shape': (3, 6),
+      }),
+      (layers.ZeroPadding2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.ZeroPadding3D, {
+          'input_shape': (3, 4, 6, 1),
+      }),
+      (layers.ActivityRegularization, {}),
+      (layers.Dense, {
+          'units': 2
+      }),
+      (layers.Dropout, {
+          'rate': 0.2
+      }),
+      (layers.Flatten, {}),
+      (layers.Masking, {}),
+      (layers.Permute, {
+          'input_shape': (10, 64),
+          'dims': (2, 1)
+      }),
+      (layers.RepeatVector, {
+          'n': 3
+      }),
+      (layers.Reshape, {
+          'target_shape': [5, 1, 1]
+      }),
+      (layers.SpatialDropout1D, {
+          'input_shape': (3, 6),
+          'rate': 0.2,
+      }),
+      (layers.SpatialDropout2D, {
+          'input_shape': (4, 6, 1),
+          'rate': 0.2,
+      }),
+      (layers.SpatialDropout3D, {
+          'input_shape': (3, 4, 6, 1),
+          'rate': 0.2,
+      }),
+      (layers.AveragePooling1D, {
+          'input_shape': (3, 6),
+      }),
+      (layers.AveragePooling2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.AveragePooling3D, {
+          'input_shape': (3, 4, 6, 1),
+      }),
+      (layers.GlobalAveragePooling1D, {
+          'input_shape': (3, 6),
+      }),
+      (layers.GlobalAveragePooling2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.GlobalAveragePooling3D, {
+          'input_shape': (3, 4, 6, 1),
+      }),
+      (layers.GlobalMaxPooling1D, {
+          'input_shape': (3, 6),
+      }),
+      (layers.GlobalMaxPooling2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.GlobalMaxPooling3D, {
+          'input_shape': (3, 4, 6, 1),
+      }),
+      (layers.MaxPooling1D, {
+          'input_shape': (3, 6),
+      }),
+      (layers.MaxPooling2D, {
+          'input_shape': (4, 6, 1),
+      }),
+      (layers.MaxPooling3D, {
+          'input_shape': (3, 4, 6, 1),
+      }),
+      # LocallyConnected1D implementations use significantly different TF
+      # operations underneath, so they should be all tested.
+      (layers.LocallyConnected1D, {
+          'input_shape': (3, 6),
+          'implementation': 1,
+          'filters': 4,
+          'kernel_size': 2
+      }),
+      (layers.LocallyConnected1D, {
+          'input_shape': (3, 6),
+          'implementation': 2,
+          'filters': 4,
+          'kernel_size': 2
+      }),
+      (layers.LocallyConnected1D, {
+          'input_shape': (3, 6),
+          'implementation': 3,
+          'filters': 4,
+          'kernel_size': 2
+      }),
+      (layers.LocallyConnected2D, {
+          'input_shape': (4, 6, 1),
+          'implementation': 1,
+          'filters': 4,
+          'kernel_size': (2, 2)
+      }),
+      (layers.LocallyConnected2D, {
+          'input_shape': (4, 6, 1),
+          'implementation': 2,
+          'filters': 4,
+          'kernel_size': (2, 2)
+      }),
+      (layers.LocallyConnected2D, {
+          'input_shape': (4, 6, 1),
+          'implementation': 3,
+          'filters': 4,
+          'kernel_size': (2, 2)
+      }),
+  ]
+
+  # pylint: disable=g-complex-comprehension,undefined-variable
+
+  @parameterized.parameters([
+      l for l in _LAYER_PARAMS if l[0] not in [
+          # Not done since TFLite converter doesn't support in TF2 yet.
+          layers.UpSampling2D,
+          layers.Conv3D,
+          layers.Conv3DTranspose,
+          layers.AveragePooling3D,
+          layers.MaxPooling3D,
+          layers.LocallyConnected1D,
+          layers.LocallyConnected2D,
+          # Not done since TFLite inference doesn't support yet.
+          layers.ZeroPadding3D,  # Does not support 5D inputs yet.
+          # Not done because converter transforms graph until there are
+          # zero ops, and then an error is thrown because it cannot handle
+          # zero op graphs.
+          layers.ActivityRegularization,
+          layers.Dropout,
+          layers.Flatten,
+          layers.SpatialDropout1D,
+          layers.SpatialDropout2D,
+          layers.SpatialDropout3D,
+          # Not done since there are float tensors besides
+          # the inputs and outputs (e.g. FakeQuant not placed in
+          # all areas or converter support not there).
+          layers.Masking,
+          layers.RepeatVector,
+          layers.MaxPooling1D,
+          layers.UpSampling1D,
+          layers.UpSampling3D,
+      ]
+  ])
+  def testQuantizeSingleLayer_ProducesFullIntegerModel_TF2(
+      self, layer_type, kwargs):
+    # "FullInteger" in the sense that ignores inputs and outputs.
+    if compat.is_v1_apis():
+      return
+
+    if 'input_shape' not in kwargs:
+      kwargs['input_shape'] = (5,)
+
+    layer = layer_type(**kwargs)
+    model = tf.keras.Sequential([layer])
+    quantized_model = quantize.quantize_model(model)
+
+    _, quantized_tflite_file = tempfile.mkstemp('.tflite')
+
+    with quantize.quantize_scope():
+      test_utils.convert_keras_to_tflite(
+          model=quantized_model,
+          output_path=quantized_tflite_file,
+          is_quantized=True,
+          input_quant_params=(0., 1.),
+          experimental_new_converter=True)
+
+    interpreter = tf.lite.Interpreter(model_path=quantized_tflite_file)
+    interpreter.allocate_tensors()
+
+    input_tensor_details = interpreter.get_input_details()
+    self.assertEqual(input_tensor_details[0]['dtype'], np.float32)
+
+    output_tensor_details = interpreter.get_output_details()
+    self.assertEqual(output_tensor_details[0]['dtype'], np.float32)
+
+    tensor_details = interpreter.get_tensor_details()
+    float_tensor_details = [
+        t for t in tensor_details if t['dtype'] == np.float32
+    ]
+    # Only the input and outputs are float. The rest are integer.
+    #
+    # TODO(tfmot): update this test to use the full-integer path when available,
+    # so that float_tensor_details should be length 0.
+    self.assertLen(float_tensor_details, 2)
+
+  # This unit test runs in TF1. While we don't publicly support this path in
+  # the Keras tooling, this is useful for two reasons:
+  # 1. TOCO has better debugging functionality than MLIR, for incrementally
+  #    adding new layers.
+  # 2. It's useful to track supported layers in TF1 converter in case we
+  #    want to eventually support V1 conversion.
+  # 3. This also tracks more layers where FakeQuant placement is incorrect,
+  #    given that the TF2 converter doesn't support all layers that TF1 did.
+  @parameterized.parameters([
+      l for l in _LAYER_PARAMS if l[0] not in [
+          # Not done since per-channel not supported in TF1 without MLIR.
+          # By temporarily switching layers to be per-tensor instead of
+          # per-channel, some minimum testing can be done.
+          #
+          # TODO(tfmot): add Conv1D/Conv3D/Conv with Transpose after they
+          # are made per-channel by quantization scheme.
+          layers.Conv2D,
+          # Not done since FakeQuants are not placed in right areas or
+          # converter doesn't handle it properly yet.
+          layers.Conv3D,
+          layers.Conv3DTranspose,
+          layers.Masking,
+          layers.LocallyConnected1D,
+          # TODO(tfmot): find reason.
+          layers.LocallyConnected2D,
+          # Not done because TF1 converter doesn't support quantized op.
+          layers.AveragePooling3D,
+          layers.MaxPooling3D,
+          # Not done because TF1 converter transforms graph until there are
+          # zero ops, and then an error is thrown because it cannot handle
+          # zero op graphs.
+          layers.ActivityRegularization,
+          layers.Dropout,
+          layers.Flatten,
+          layers.SpatialDropout1D,
+          layers.SpatialDropout2D,
+          layers.SpatialDropout3D,
+      ]
+  ])
+  def testQuantizeSingleLayer_ProducesFullIntegerModel_TF1(
+      self, layer_type, kwargs):
+    if not compat.is_v1_apis():
+      return
+
+    if 'input_shape' not in kwargs:
+      kwargs['input_shape'] = (5,)
+
+    layer = layer_type(**kwargs)
+    model = tf.keras.Sequential([layer])
+    quantized_model = quantize.quantize_model(model)
+
+    with quantize.quantize_scope():
+      test_utils.convert_keras_to_tflite(
+          model=quantized_model,
+          output_path=None,
+          is_quantized=True,
+          inference_type=tf.uint8,
+          inference_input_type=tf.uint8,
+          input_quant_params=(0., 1.),
+          # Set to False to throw errors when FakeQuants are
+          # not placed everywhere to create full-integer model. Errors
+          # are not thrown when set to True.
+          experimental_new_converter=False)
+
+  # pylint: enable=g-complex-comprehension,undefined-variable
+
+
 if __name__ == '__main__':
   tf.test.main()
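
The TF2 test above asserts that, after conversion, only the model input and output tensors remain float32. The same spot check can be run on any converted flatbuffer with the interpreter API; the file path below is a placeholder, not something produced by this commit.

import numpy as np
import tensorflow as tf

# Placeholder path to a converted, quantized .tflite file.
interpreter = tf.lite.Interpreter(model_path='/tmp/quantized_layer.tflite')
interpreter.allocate_tensors()

float_tensors = [
    t for t in interpreter.get_tensor_details() if t['dtype'] == np.float32
]
# With the current converter path, exactly two tensors (the model input and
# output) are expected to stay float32; everything in between should be integer.
print('float32 tensors:', [t['name'] for t in float_tensors])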

tensorflow_model_optimization/python/core/quantization/keras/utils.py

Lines changed: 8 additions & 2 deletions
@@ -26,8 +26,10 @@ def convert_keras_to_tflite(model,
                             output_path,
                             custom_objects=None,
                             is_quantized=True,
+                            inference_type=None,
                             inference_input_type=None,
-                            input_quant_params=(-128., 255.)):
+                            input_quant_params=(-128., 255.),
+                            experimental_new_converter=True):
   """Convert Keras model to TFLite."""
   if custom_objects is None:
     custom_objects = {}
@@ -40,14 +42,18 @@ def convert_keras_to_tflite(model,
   converter = tf.lite.TFLiteConverter.from_keras_model_file(
       keras_file, custom_objects=custom_objects)
 
-  converter.experimental_new_converter = True
+  converter.experimental_new_converter = experimental_new_converter
 
   if is_quantized:
     if not compat.is_v1_apis():
       converter.optimizations = [tf.lite.Optimize.DEFAULT]
     else:
       converter.inference_type = tf.lite.constants.INT8
       converter.inference_input_type = tf.lite.constants.FLOAT
+      # TODO(tfmot): should be able to make everything use the
+      # same inference_type in TF 1.X tests.
+      if inference_type:
+        converter.inference_type = inference_type
       if inference_input_type:
         converter.inference_input_type = inference_input_type
 