
Commit 70494e3

alanchiao authored and tensorflower-gardener committed
- Update QAT docs to clarify support.
- Update QAT tutorials based on 0.3.0 release code and clarify custom Keras layers use case.

PiperOrigin-RevId: 305277833
1 parent a5d7780 commit 70494e3

File tree

3 files changed: +25, -16 lines


tensorflow_model_optimization/g3doc/guide/quantization/training.md

Lines changed: 9 additions & 4 deletions
@@ -24,12 +24,12 @@ leading to benefits during deployment.
 
 Quantization brings improvements via model compression and latency reduction.
 With the API defaults, the model size shrinks by 4x, and we typically see
-between 1.5 - 4x improvements in CPU latency in the tested backends. Further
+between 1.5 - 4x improvements in CPU latency in the tested backends. Eventually,
 latency improvements can be seen on compatible machine learning accelerators,
 such as the [EdgeTPU](https://coral.ai/docs/edgetpu/benchmarks/) and NNAPI.
 
 The technique is used in production in speech, vision, text, and translate use
-cases. The code currently supports vision use cases and will expand over time.
+cases. The code currently supports a subset of these models.
 
 #### Experiment with quantization and associated hardware

@@ -62,10 +62,12 @@ Support is available in the following areas:
 
 * Model coverage: models using
   [whitelisted layers](https://github.com/tensorflow/model-optimization/tree/master/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py),
-  BatchNormalization, and in limited cases, Concat.
+  BatchNormalization when it follows a convolutional or `Dense` layer, and in
+  limited cases, `Concat`.
   <!-- TODO(tfmot): add more details and ensure they are all correct. -->
 * Hardware acceleration: our API defaults are compatible with acceleration on
-  EdgeTPU, NNAPI, and TFLite backends, amongst others.
+  EdgeTPU, NNAPI, and TFLite backends, amongst others. See the caveat in the
+  roadmap.
 * Deploy with quantization: only per-axis quantization for convolutional
   layers, not per-tensor quantization, is currently supported.

@@ -75,6 +77,9 @@ It is on our roadmap to add support in the following areas:
 to launch. -->
 
 * Model coverage: extended to include RNN/LSTMs and general Concat support.
+* Hardware acceleration: ensure the TFLite converter can produce full-integer
+  models. See [this
+  issue](https://github.com/tensorflow/tensorflow/issues/38285) for details.
 * Experiment with quantization use cases:
   * Experiment with quantization algorithms that span Keras layers or
     require the training step.
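
As context for the guide text above, the "API defaults" are applied through `tfmot.quantization.keras.quantize_model`. A minimal sketch; the small Sequential model here is illustrative and not part of this commit:

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# An illustrative model built only from whitelisted layers.
base_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28)),
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# Apply the API defaults: 8-bit quantization-aware training, with
# per-axis quantization for convolutional kernels as noted above.
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
```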

tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb

Lines changed: 13 additions & 9 deletions
@@ -102,8 +102,8 @@
 "outputs": [],
 "source": [
 "! pip uninstall -y tensorflow\n",
-"! pip install -q tf-nightly==2.2.0.dev20200315\n",
-"! pip install -q --extra-index-url=https://test.pypi.org/simple/ tensorflow-model-optimization==0.3.0.dev6\n",
+"! pip install -q tf-nightly\n",
+"! pip install -q tensorflow-model-optimization\n",
 "\n",
 "import tensorflow as tf\n",
 "import numpy as np\n",
@@ -610,12 +610,11 @@
 "id": "YmyhI_bzWb2w"
 },
 "source": [
-"This example uses the `DefaultDenseQuantizeConfig` to quantize a `Dense` layer. In practice, the layer\n",
-"can be any custom Keras layer.\n",
+"This example uses the `DefaultDenseQuantizeConfig` to quantize the `CustomLayer`.\n",
 "\n",
 "Applying the configuration is the same across\n",
 "the \"Experiment with quantization\" use cases.\n",
-" * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `Dense` layer and pass in the `QuantizeConfig`.\n",
+" * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `CustomLayer` and pass in the `QuantizeConfig`.\n",
 " * Use\n",
 "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults.\n",
 "\n"
@@ -635,14 +634,19 @@
 "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
 "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
 "\n",
+"class CustomLayer(tf.keras.layers.Dense):\n",
+"  pass\n",
+"\n",
 "model = quantize_annotate_model(tf.keras.Sequential([\n",
-"   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),\n",
+"   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),\n",
 "   tf.keras.layers.Flatten()\n",
 "]))\n",
 "\n",
-"# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`:\n",
+"# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`\n",
+"# as well as the custom Keras layer.\n",
 "with quantize_scope(\n",
-"  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig}):\n",
+"  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,\n",
+"   'CustomLayer': CustomLayer}):\n",
 "  # Use `quantize_apply` to actually make the model quantization aware.\n",
 "  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)\n",
 "\n",
@@ -864,7 +868,7 @@
 "    # Not needed. No new TensorFlow variables needed.\n",
 "    return {}\n",
 "\n",
-"  def __call__(self, inputs, step, training, **kwargs):\n",
+"  def __call__(self, inputs, training, weights, **kwargs):\n",
 "    return tf.keras.backend.clip(inputs, -1.0, 1.0)\n",
 "\n",
 "  def get_config(self):\n",

tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb

Lines changed: 3 additions & 3 deletions
@@ -118,8 +118,8 @@
 "outputs": [],
 "source": [
 "! pip uninstall -y tensorflow\n",
-"! pip install -q tf-nightly==2.2.0.dev20200305\n",
-"! pip install -q --extra-index-url=https://test.pypi.org/simple/ tensorflow-model-optimization==0.3.0.dev6\n"
+"! pip install -q tf-nightly\n",
+"! pip install -q tensorflow-model-optimization\n"
 ]
 },
 {
@@ -222,7 +222,7 @@
 "\n",
 "Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The sections after show how to create a quantized model from the quantization aware one.\n",
 "\n",
-"In the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.ipynb), you can see how to quantize some layers for model accuracy improvements."
+"In the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md), you can see how to quantize some layers for model accuracy improvements."
 ]
 },
 {
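
The "sections after" that this note points to convert the quantization-aware model with the TFLite converter. A minimal sketch, assuming `quant_aware_model` from the tutorial:

```python
import tensorflow as tf

# Create an actually-quantized model (int8 weights) from the
# quantization-aware one.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

with open('quantized_model.tflite', 'wb') as f:
  f.write(quantized_tflite_model)
```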
