Skip to content

Commit b0ab25f

Browse files
nutsiepully and tensorflower-gardener
authored and committed
Improve API pydocs for QAT API
PiperOrigin-RevId: 305127202
1 parent de7187c commit b0ab25f

File tree

5 files changed

+192
-60
lines changed

5 files changed

+192
-60
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 101 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,12 @@
3030

3131

3232
def quantize_scope(*args):
33-
"""Required scope to deserialize quantized models stored in tf.keras h5 format.
33+
"""Scope which can be used to deserialize quantized Keras models and layers.
3434
35-
Args:
36-
*args: Variable length list of dictionaries of name, class pairs to add to
37-
the scope created by this method.
38-
39-
Returns:
40-
Object of type `CustomObjectScope` with quantization objects included.
35+
Under `quantize_scope`, Keras methods such as `tf.keras.load_model` and
36+
`tf.keras.models.model_from_config` will be able to deserialize Keras models
37+
and layers which contain quantization classes such as `QuantizeConfig`
38+
and `Quantizer`.
4139
4240
Example:
4341
@@ -46,7 +44,21 @@ def quantize_scope(*args):
4644
4745
with quantize_scope():
4846
loaded_model = tf.keras.models.load_model(keras_file)
47+
48+
# If your quantized model uses custom objects such as a specific `Quantizer`,
49+
# you can pass them to quantize_scope to deserialize your model.
50+
with quantize_scope({'FixedRangeQuantizer': FixedRangeQuantizer}):
51+
loaded_model = tf.keras.models.load_model(keras_file)
4952
```
53+
54+
For further understanding, see `tf.keras.utils.custom_object_scope`.
55+
56+
Args:
57+
*args: Variable length list of dictionaries of `{name, class}` pairs to add
58+
to the scope created by this method.
59+
60+
Returns:
61+
Object of type `CustomObjectScope` with quantization objects included.
5062
"""
5163
quantization_objects = {
5264
'QuantizeAnnotate': quantize_annotate_mod.QuantizeAnnotate,
@@ -65,34 +77,46 @@ def quantize_scope(*args):
6577
return tf.keras.utils.custom_object_scope(*(args + (quantization_objects,)))
6678

6779

68-
# TODO(tfmot): link to docs to explain what quantization implementation means.
6980
def quantize_model(to_quantize):
70-
"""Quantize a whole tf.keras model with the default quantization implementation.
81+
"""Quantize a `tf.keras` model with the default quantization implementation.
82+
83+
Quantization constructs a model which emulates quantization during training.
84+
This allows the model to learn parameters robust to quantization loss, and
85+
also model the accuracy of a quantized model.
7186
72-
To be more precise, `quantize_model` creates a model that emulates
73-
quantization during training and stores information that downstream
74-
tools will use to produce actually quantized models.
87+
For more information, see
88+
https://www.tensorflow.org/model_optimization/guide/quantization/training
7589
7690
Quantize a model:
7791
7892
```python
93+
# Quantize sequential model
7994
model = quantize_model(
8095
keras.Sequential([
8196
layers.Dense(10, activation='relu', input_shape=(100,)),
8297
layers.Dense(2, activation='sigmoid')
8398
]))
99+
100+
# Quantize functional model
101+
inp = tf.keras.Input((3,))
102+
out = tf.keras.layers.Dense(2)(inp)
103+
model = tf.keras.Model(inp, out)
104+
105+
quantized_model = quantize_model(model)
84106
```
85107
86108
Note that this function removes the optimizer from the original model.
87-
Additionally, training the model returned by `quantize_model` will not affect
88-
the weights of the original model.
109+
110+
The returned model copies over weights from the original model. So while
111+
it preserves the original weights, training it will not modify the weights
112+
of the original model.
89113
90114
Args:
91115
to_quantize: tf.keras model to be quantized. It can have pre-trained
92116
weights.
93117
94118
Returns:
95-
Returns a new tf.keras model prepared for quantization.
119+
Returns a new `tf.keras` model prepared for quantization.
96120
"""
97121
if to_quantize is None:
98122
raise ValueError('`to_quantize` cannot be None')
@@ -115,35 +139,43 @@ def quantize_model(to_quantize):
115139

116140

117141
def quantize_annotate_model(to_annotate):
118-
"""Annotate a model to be quantized.
142+
"""Annotate a `tf.keras` model to be quantized.
119143
120-
This function does not actually quantize anything. It is merely to specify
121-
that the model needs to be quantized.
144+
This function does not actually quantize the model. It merely specifies
145+
that the model needs to be quantized. `quantize_apply` can then be used
146+
to quantize the model.
122147
123148
This function is intended to be used in conjunction with the
124-
`quantize_annotate_layer`
125-
API. It's otherwise simpler to use `quantize_model`.
149+
`quantize_annotate_layer` API. Otherwise, it is simpler to use
150+
`quantize_model`.
126151
127-
Annotate a model while overriding the default behavior for one layer:
152+
Annotate a model while overriding the default behavior for a layer:
128153
129154
```python
130155
quantize_config = MyDenseQuantizeConfig()
131156
132-
model = quantize_annotate_model(keras.Sequential([
157+
model = quantize_annotate_model(
158+
keras.Sequential([
133159
layers.Dense(10, activation='relu', input_shape=(100,)),
134-
quantize_annotate_layer(layers.Dense(2, activation='sigmoid'),
135-
quantize_config=quantize_config)
136-
])))
160+
quantize_annotate_layer(
161+
layers.Dense(2, activation='sigmoid'),
162+
quantize_config=quantize_config)
163+
]))
164+
165+
# The first Dense layer gets quantized with the default behavior,
166+
# but the second layer uses `MyDenseQuantizeConfig` for quantization.
167+
quantized_model = quantize_apply(model)
137168
```
138169
139170
Note that this function removes the optimizer from the original model.
140171
141172
Args:
142-
to_annotate: tf.keras model to annotate to be quantized.
173+
to_annotate: `tf.keras` model which needs to be quantized.
143174
144175
Returns:
145176
New tf.keras model with each layer in the model wrapped with
146-
`QuantizeAnnotate`.
177+
`QuantizeAnnotate`. The new model preserves weights from the original
178+
model.
147179
"""
148180
if to_annotate is None:
149181
raise ValueError('`to_annotate` cannot be None')
@@ -179,28 +211,35 @@ def _add_quant_wrapper(layer):
179211

180212

181213
def quantize_annotate_layer(to_annotate, quantize_config=None):
182-
"""Annotate a layer to be quantized.
214+
"""Annotate a `tf.keras` layer to be quantized.
215+
216+
This function does not actually quantize the layer. It is merely used to
217+
specify that the layer should be quantized. The layer then gets quantized
218+
accordingly when `quantize_apply` is used.
183219
184-
This function does not actually quantize anything. It is merely to specify
185-
that the layer needs to be quantized.
220+
This method should be used when the user wants to quantize only certain
221+
layers of the model, or change the default behavior of how a layer is
222+
quantized.
186223
187224
Annotate a layer:
188225
189226
```python
190227
model = keras.Sequential([
191228
layers.Dense(10, activation='relu', input_shape=(100,)),
192229
quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
193-
]))
194-
```
230+
])
195231
196-
Note that this function removes the optimizer from the original model.
232+
# Only the second Dense layer is quantized.
233+
quantized_model = quantize_apply(model)
234+
```
197235
198236
Args:
199-
to_annotate: tf.keras layer to annotate to be quantized.
200-
quantize_config: `QuantizeConfig` to quantize layer.
237+
to_annotate: `tf.keras` layer which needs to be quantized.
238+
quantize_config: optional `QuantizeConfig` which controls how the layer is
239+
quantized. In its absence, the default behavior for the layer is used.
201240
202241
Returns:
203-
tf.keras layer wrapped with `QuantizeAnnotate`.
242+
`tf.keras` layer wrapped with `QuantizeAnnotate`.
204243
"""
205244
if to_annotate is None:
206245
raise ValueError('`to_annotate` cannot be None')
@@ -225,31 +264,43 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
225264

226265

227266
def quantize_apply(model):
228-
"""Introduce quantization operations to a tf.keras model.
267+
"""Quantize a `tf.keras` model.
229268
230-
This function takes a tf.keras model which has been annotated with
231-
`quantize_annotate` and constructs a new model in which each of the
232-
annotated layers will ultimately be quantized. The new quantization
233-
operations enable the model to **emulate* quantization during training
234-
and store information that downstream tools will use to produce
235-
an actually quantized model.
269+
Quantization constructs a model which emulates quantization during training.
270+
This allows the model to learn parameters robust to quantization loss, and
271+
also model the accuracy of a quantized model.
236272
237-
Apply quantization to a model:
273+
For more information, see
274+
https://www.tensorflow.org/model_optimization/guide/quantization/training
275+
TODO(tfmot): Link blog once launched.
238276
277+
This function takes a `tf.keras` model in which the desired layers for
278+
quantization have already been annotated. See `quantize_annotate_model`
279+
and `quantize_annotate_layer`.
280+
281+
Quantize a model:
239282
```python
240-
model = quantize_apply(annotated_model)
283+
model = keras.Sequential([
284+
layers.Dense(10, activation='relu', input_shape=(100,)),
285+
quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
286+
])
287+
288+
# Only the second Dense layer is quantized.
289+
quantized_model = quantize_apply(model)
241290
```
242291
243292
Note that this function removes the optimizer from the original model.
244-
Additionally, training the model returned by `quantize_apply` will not affect
245-
the weights of the original model.
293+
294+
The returned model copies over weights from the original model. So while
295+
it preserves the original weights, training it will not modify the weights
296+
of the original model.
246297
247298
Args:
248-
model: A tf.keras Sequential or Functional model which has been annotated
249-
with `quantize_annotate`. It can have pre-trained weights.
299+
model: A `tf.keras` Sequential or Functional model which has been annotated
300+
with `quantize_annotate`. It can have pre-trained weights.
250301
251302
Returns:
252-
Returns a new tf.keras model in which the annotated layers have been
303+
Returns a new `tf.keras` model in which the annotated layers have been
253304
prepared for quantization.
254305
"""
255306
if model is None:

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, layer, quantize_config=None, **kwargs):
4747
4848
Args:
4949
layer: The keras layer to be quantized.
50-
quantize_config: `QuantizeConfig` to quantize layer.
50+
quantize_config: Optional `QuantizeConfig` to quantize the layer.
5151
**kwargs: Additional keyword arguments to be passed to the keras layer.
5252
"""
5353
super(QuantizeAnnotate, self).__init__(layer, **kwargs)

0 commit comments

Comments
 (0)