Skip to content

Commit 6b8edcd

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Improve pydoc for Quantizer
PiperOrigin-RevId: 305288335
1 parent bf67a58 commit 6b8edcd

File tree

1 file changed

+47
-5
lines changed
  • tensorflow_model_optimization/python/core/quantization/keras

1 file changed

+47
-5
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantizers.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,50 @@
3333

3434
@six.add_metaclass(abc.ABCMeta)
3535
class Quantizer(object):
36-
"""ABC interface which contains logic to quantize a tensor."""
36+
"""ABC interface which encapsulates the logic of how to quantize tensors.
37+
38+
A `Quantizer` is used by the library code to apply the mathematical
39+
transformations which actually quantize a tensor, hence allowing the user
40+
precise control over the algorithm with which tensors are quantized. When used
41+
in conjunction with `QuantizeConfig` it controls how a layer is quantized.
42+
43+
Create a custom quantizer:
44+
45+
```python
46+
class FixedRangeQuantizer(Quantizer):
47+
# Example quantizer which clips tensors in a fixed range.
48+
49+
def build(self, tensor_shape, name, layer):
50+
range_var = layer.add_weight(
51+
name + '_range',
52+
initializer=keras.initializers.Constant(6.0),
53+
trainable=False)
54+
55+
return {
56+
'range_var': range_var,
57+
}
58+
59+
def __call__(self, inputs, training, weights, **kwargs):
60+
return tf.keras.backend.clip(
61+
inputs, 0.0, weights['range_var'])
62+
63+
def get_config(self):
64+
# Not needed. No __init__ parameters to serialize.
65+
return {}
66+
```
67+
68+
For a full example, see
69+
https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md
70+
"""
3771

3872
@abc.abstractmethod
3973
def build(self, tensor_shape, name, layer):
40-
"""Constructs the weights required by the quantizer.
74+
"""Construct the weights required by the quantizer.
75+
76+
A quantizer may need to construct variables to hold the state for its
77+
algorithm. This function is invoked during the `build` stage of the layer
78+
that the quantizer is used for. Any variables constructed are under the
79+
scope of the `layer` and serialized as part of the layer.
4180
4281
Args:
4382
tensor_shape: Shape of tensor which needs to be quantized.
@@ -46,27 +85,30 @@ def build(self, tensor_shape, name, layer):
4685
to construct the weights, and is also the owner of the weights.
4786
4887
Returns: Dictionary of constructed weights. This dictionary will be
49-
unpacked and passed to the quantizer's __call__ function as kwargs.
88+
passed to the quantizer's __call__ function as a `weights` dictionary.
5089
"""
5190

5291
@abc.abstractmethod
5392
def __call__(self, inputs, training, weights, **kwargs):
5493
"""Apply quantization to the input tensor.
5594
56-
The `step` variable allows a user to design a custom quantizer which
57-
modifies quantization behavior as training progresses.
95+
This is the main function of the `Quantizer` which implements the core logic
96+
to quantize the tensor. It is invoked during the `call` stage of the layer,
97+
and allows modifying the tensors used in graph construction.
5898
5999
Args:
60100
inputs: Input tensor to be quantized.
61101
training: Whether the graph is currently training.
62102
weights: Dictionary of weights the quantizer can use to quantize the
63103
tensor. This contains the weights created in the `build` function.
64104
**kwargs: Additional variables which may be passed to the quantizer.
105+
65106
Returns: quantized tensor.
66107
"""
67108

68109
@abc.abstractmethod
69110
def get_config(self):
111+
"""Returns the config used to serialize the `Quantizer`."""
70112
raise NotImplementedError('Quantizer should implement get_config().')
71113

72114
@classmethod

0 commit comments

Comments
 (0)