     "freeze_module_quantization",
     "apply_calibration_status",
     "reset_quantization_status",
-    "update_weight_global_scale",
 ]


@@ -67,13 +66,7 @@ def initialize_observer(
     module.register_module(f"{base_name}_observer", observer)


-def call_observer(
-    module: Module,
-    base_name: str,
-    value: Optional[torch.Tensor] = None,
-    should_calculate_gparam: bool = False,
-    should_calculate_qparams: bool = True,
-):
+def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor] = None):
     """
     Call a module's attached input/weight/output observer using a provided value.
     Update the module's scale and zp using the observer's return values.
@@ -87,51 +80,54 @@ def call_observer(
     if base_name == "weight":
         value = module.weight
         g_idx = getattr(module, "weight_g_idx", None)
+        global_scale = getattr(module, f"{base_name}_global_scale", None)
     elif value is not None:
         g_idx = None
+        global_scale = None
     else:
         raise ValueError(
             "Must provide a value to observe if not using weight observer"
         )

+    quantization_scheme = getattr(module, "quantization_scheme", None)
+    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
+    quant_args = getattr(quantization_scheme, arg_name, None)
+
+    # By default, calculate quantization parameters but not global parameters
+    should_calculate_gparam = False
+    should_calculate_qparams = True
+
+    # TODO: update to cover both weight and input in a follow-up;
+    # weight global scale calculation is currently done in compressed-tensors (ct)
+    # and should be moved here to unify global scale calculations
+    if (
+        quant_args.strategy == QuantizationStrategy.TENSOR_GROUP
+        and base_name == "input"
+    ):
+        should_calculate_gparam = True
+        should_calculate_qparams = False
+
     observer = getattr(module, f"{base_name}_observer")
+    observer_outputs = observer(
+        value,
+        g_idx=g_idx,
+        global_scale=global_scale,
+        should_calculate_gparam=should_calculate_gparam,
+    )

     if should_calculate_gparam:
-        global_scale = observer(
-            value,
-            should_calculate_gparam=True,
+        updated_global_scale = observer_outputs
+        update_parameter_data(
+            module, updated_global_scale, f"{base_name}_global_scale"
         )
-        update_parameter_data(module, global_scale, f"{base_name}_global_scale")
-    else:
-        global_scale = getattr(module, f"{base_name}_global_scale", None)

     if should_calculate_qparams:
-        updated_scale, updated_zero_point = observer(
-            value, g_idx=g_idx, global_scale=global_scale
-        )
+        # update scale and zero point
+        updated_scale, updated_zero_point = observer_outputs
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")


-def update_weight_global_scale(module: Module):
-    if getattr_chain(module, "quantization_scheme.weights", None) is None:
-        return
-
-    if (
-        getattr_chain(module, "quantization_scheme.weights.strategy", None)
-        != QuantizationStrategy.TENSOR_GROUP
-    ):
-        return
-
-    call_observer(
-        module,
-        base_name="weight",
-        should_calculate_gparam=True,
-        should_calculate_qparams=False,
-    )
-    module.weight_observer.reset()
-
-
 def update_weight_zp_scale(module: Module):
     """
     marks a layer as ready for calibration which activates observers
@@ -169,24 +165,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return

-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None)
-
-    calculate_qparams = True
-    calculate_gparam = False
-
-    if quantization_args is not None:
-        if quantization_args.dynamic in (True, DynamicType.LOCAL):
-            calculate_qparams = False
-        if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-            calculate_gparam = True
-
     call_observer(
         module=module,
         base_name=base_name,
         value=value,
-        should_calculate_gparam=calculate_gparam,
-        should_calculate_qparams=calculate_qparams,
     )


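For reference, a minimal sketch of how callers interact with the simplified call_observer after this change. `linear` and `sample` are illustrative placeholders (a quantized module that already has a quantization_scheme and attached observers, and a calibration batch); they are not part of the diff.

import torch

# Weight calibration: value defaults to module.weight, and whether global
# scales and/or quantization parameters are computed is now decided inside
# call_observer from the module's quantization_scheme, not by caller flags.
call_observer(linear, base_name="weight")

# Activation calibration: calibrate_activations now simply forwards the
# observed tensor; for TENSOR_GROUP input activations call_observer computes
# only the global scale, otherwise it updates input_scale / input_zero_point.
sample = torch.randn(4, linear.in_features)
calibrate_activations(linear, value=sample, base_name="input")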