from __future__ import print_function

from tensorflow.python.keras import activations
-from tensorflow.python.keras import backend as K
from tensorflow.python.keras import initializers
-from tensorflow.python.keras.layers import Layer

-from tensorflow_model_optimization.python.core.quantization.keras import quant_ops

-
-class QuantizeAwareActivation(Layer):
+class QuantizeAwareActivation(object):
  """Activation layer for quantization aware training.

  The goal of this layer is to apply quantize operations during training such
@@ -51,107 +47,70 @@ class QuantizeAwareActivation(Layer):

  _PRE_ACTIVATION_TYPES = {'softmax'}

-  def __init__(
-      self,
-      activation,
-      parent_layer,
-      num_bits,
-      symmetric=True,
-      **kwargs):
+  def __init__(self, activation, quantizer, step, quantize_wrapper):
    """Construct a QuantizeAwareActivation layer.

    Args:
-      activation: Activation function to use.
-        If you don't specify anything, no activation is applied
+      activation: Activation function to use. If you don't specify anything, no
+        activation is applied
        (ie. "linear" activation: `a(x) = x`).
-      parent_layer: The layer this activation is being applied to. Such
-        as Conv2D, Dense etc.
-      num_bits: Number of bits for quantization
-      symmetric: If true, use symmetric quantization limits instead of training
-        the minimum and maximum of each quantization range separately.
-      **kwargs: Additional keyword arguments to be passed to the keras layer.
+      quantizer: `Quantizer` to be used to quantize the activation.
+      step: Variable which tracks optimizer step.
+      quantize_wrapper: `QuantizeWrapper` which owns this activation.
    """
-    super(QuantizeAwareActivation, self).__init__(**kwargs)
-
    self.activation = activations.get(activation)
-    self.parent_layer = parent_layer
+    self.quantizer = quantizer
+    self.step = step
+    self.quantize_wrapper = quantize_wrapper
+
+    self._training = False

-    self.num_bits = num_bits
-    self.symmetric = symmetric
+    if self._should_pre_quantize():
+      self._min_pre_activation, self._max_pre_activation = \
+          self._add_range_weights('pre_activation')

-    # TODO(pulkitb): Generate a meaningful name for this layer, which
-    # ideally also includes the parent layer.
+    self._min_post_activation, self._max_post_activation = \
+        self._add_range_weights('post_activation')

-  def _requires_pre_quant(self):
-    # TODO(pulkitb): Make this more sophisticated. This should match the
-    # implementation of kernels on-device.
+  def _should_pre_quantize(self):
+    # TODO(pulkitb): Add logic to deduce whether we should pre-quantize.
+    # Whether we apply quantize operations around activations depends on the
+    # implementation of the specific kernel. For example, ReLUs are fused in
+    # whereas Softmax ops are not. Should linear have post-quantize?
    return self.activation.__name__ in self._PRE_ACTIVATION_TYPES

-  def build(self, input_shape):
-    if self._requires_pre_quant():
-      self._min_pre_activation = self.add_variable(
-          'min_pre_activation',
-          initializer=initializers.Constant(-6.0),
-          trainable=False)
-      self._max_pre_activation = self.add_variable(
-          'max_pre_activation',
-          initializer=initializers.Constant(6.0),
-          trainable=False)
-
-    self._min_post_activation = self.add_variable(
-        'min_post_activation',
-        initializer=initializers.Constant(-6.0),
-        trainable=False)
-    self._max_post_activation = self.add_variable(
-        'max_post_activation',
-        initializer=initializers.Constant(6.0),
-        trainable=False)
-
-  def call(self, inputs, training=None):
-    # TODO(pulkitb): Construct graph for both training/eval modes.
-    if training is None:
-      training = K.learning_phase()
+  def _add_range_weights(self, name):
+    min_var = self.quantize_wrapper.add_weight(
+        name + '_min', initializer=initializers.Constant(-6.0), trainable=False)
+    max_var = self.quantize_wrapper.add_weight(
+        name + '_max', initializer=initializers.Constant(6.0), trainable=False)
+
+    return min_var, max_var
+
+  @property
+  def training(self):
+    return self._training

+  @training.setter
+  def training(self, value):
+    self._training = value
+
+  def __call__(self, inputs, *args, **kwargs):
+    # TODO(pulkitb): Add cond here to handle training properly.
    x = inputs
-    if self._requires_pre_quant():
-      x = quant_ops.MovingAvgQuantize(
-          inputs,
-          self._min_pre_activation,
-          self._max_pre_activation,
-          ema_decay=0.999,
-          is_training=training,
-          num_bits=self.num_bits,
-          symmetric=self.symmetric,
-          name_prefix=self.name)
-
-    x = self.activation(x)
-    x = quant_ops.MovingAvgQuantize(
-        x,
-        self._min_post_activation,
-        self._max_post_activation,
-        ema_decay=0.999,
-        is_training=training,
-        num_bits=self.num_bits,
-        symmetric=self.symmetric,
-        name_prefix=self.name)
+    if self._should_pre_quantize():
+      x = self.quantizer(
+          x, self.step, self._training, **{
+              'min_var': self._min_pre_activation,
+              'max_var': self._max_pre_activation
+          })

-    return x
+    x = self.activation(x, *args, **kwargs)

-  def get_quantize_params(self):
-    return {
-        'num_bits': self.num_bits,
-        'symmetric': self.symmetric,
-    }
-
-  def compute_output_shape(self, input_shape):
-    return input_shape
-
-  def get_config(self):
-    base_config = super(QuantizeAwareActivation, self).get_config()
-    config = {
-        'activation': activations.serialize(self.activation),
-        'parent_layer': self.parent_layer,
-        'num_bits': self.num_bits,
-        'symmetric': self.symmetric,
-    }
-    return dict(list(base_config.items()) + list(config.items()))
+    x = self.quantizer(
+        x, self.step, self._training, **{
+            'min_var': self._min_post_activation,
+            'max_var': self._max_post_activation
+        })
+
+    return x
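
The new code delegates all quantize operations to an injected `quantizer`, invoked as `quantizer(x, step, training, min_var=..., max_var=...)`. For illustration only, a callable compatible with that signature could look like the sketch below. This is not the package's real `Quantizer` implementation: it ignores `step`, assumes eager execution, and approximates the removed `quant_ops.MovingAvgQuantize` with a simple moving-average range update plus `tf.quantization.fake_quant_with_min_max_vars`.

import tensorflow as tf


class EmaFakeQuantizer(object):
  """Hypothetical stand-in for the `Quantizer` collaborator used above."""

  def __init__(self, num_bits=8, ema_decay=0.999):
    self.num_bits = num_bits
    self.ema_decay = ema_decay

  def __call__(self, inputs, step, training, min_var=None, max_var=None):
    # `min_var`/`max_var` are the non-trainable range weights created by
    # QuantizeAwareActivation._add_range_weights(); `step` is unused here.
    if training:
      # Track the observed activation range with an exponential moving average.
      batch_min = tf.math.reduce_min(inputs)
      batch_max = tf.math.reduce_max(inputs)
      min_var.assign(
          self.ema_decay * min_var + (1.0 - self.ema_decay) * batch_min)
      max_var.assign(
          self.ema_decay * max_var + (1.0 - self.ema_decay) * batch_max)
    # Fake-quantize the activations using the tracked range.
    return tf.quantization.fake_quant_with_min_max_vars(
        inputs, min_var, max_var, num_bits=self.num_bits)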
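Since `QuantizeAwareActivation` is no longer a Keras `Layer`, its owner must create the range weights (via `add_weight`), supply the step variable, and propagate the training flag through the new `training` setter before calling it. A rough usage sketch, assuming a modern `tf.keras` layer as the owner; `ToyWrapper` and `EmaFakeQuantizer` are illustrative names and not the library's actual `QuantizeWrapper`:

import tensorflow as tf

# Assumes QuantizeAwareActivation (this module) and the EmaFakeQuantizer
# sketch above are in scope.


class ToyWrapper(tf.keras.layers.Layer):
  """Hypothetical minimal owner supplying what QuantizeAwareActivation needs."""

  def __init__(self, **kwargs):
    super(ToyWrapper, self).__init__(**kwargs)
    self.optimizer_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    # The activation registers its min/max range weights on this layer via
    # self.quantize_wrapper.add_weight(...) inside its constructor.
    self.quant_activation = QuantizeAwareActivation(
        activation='relu',
        quantizer=EmaFakeQuantizer(),  # sketch above
        step=self.optimizer_step,
        quantize_wrapper=self)

  def call(self, inputs, training=None):
    # The owner propagates the training flag; this sketch assumes `training`
    # is a plain Python bool (eager execution), matching the TODO above about
    # proper training handling.
    self.quant_activation.training = bool(training)
    return self.quant_activation(inputs)


# layer = ToyWrapper()
# y = layer(tf.random.uniform([2, 4]), training=True)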