@six.add_metaclass(abc.ABCMeta)
class Quantizer(object):
  """ABC interface which encapsulates the logic of how to quantize tensors.

  A `Quantizer` is used by the library code to apply the mathematical
  transformations which actually quantize a tensor, hence allowing the user
  precise control over the algorithm with which tensors are quantized. When used
  in conjunction with `QuantizeConfig` it controls how a layer is quantized.

  Create a custom quantizer:

  ```python
  class FixedRangeQuantizer(Quantizer):
    # Example quantizer which clips tensors in a fixed range.

    def build(self, tensor_shape, name, layer):
      range_var = layer.add_weight(
          name + '_range',
          initializer=keras.initializers.Constant(6.0),
          trainable=False)

      return {
          'range_var': range_var,
      }

    def __call__(self, inputs, training, weights, **kwargs):
      return tf.keras.backend.clip(
          inputs, 0.0, weights['range_var'])

    def get_config(self):
      # Not needed. No __init__ parameters to serialize.
      return {}
  ```

  For a full example, see
  https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md
  """
  @abc.abstractmethod
  def build(self, tensor_shape, name, layer):
    """Construct the weights required by the quantizer.

    A quantizer may need to construct variables to hold the state for its
    algorithm. This function is invoked during the `build` stage of the layer
    that the quantizer is used for. Any variables constructed are under the
    scope of the `layer` and serialized as part of the layer.

    Args:
      tensor_shape: Shape of tensor which needs to be quantized.
      name: Name of tensor.
      layer: Keras layer which is quantizing the tensors. The layer is needed
        to construct the weights, and is also the owner of the weights.

    Returns: Dictionary of constructed weights. This dictionary will be
      passed to the quantizer's __call__ function as a `weights` dictionary.
    """
  @abc.abstractmethod
  def __call__(self, inputs, training, weights, **kwargs):
    """Apply quantization to the input tensor.

    This is the main function of the `Quantizer` which implements the core logic
    to quantize the tensor. It is invoked during the `call` stage of the layer,
    and allows modifying the tensors used in graph construction.

    Args:
      inputs: Input tensor to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights the quantizer can use to quantize the
        tensor. This contains the weights created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns: quantized tensor.
    """
  @abc.abstractmethod
  def get_config(self):
    """Returns the config used to serialize the `Quantizer`."""
    raise NotImplementedError('Quantizer should implement get_config().')
71
113
72
114
@classmethod
0 commit comments