     QuantizationArgs,
     QuantizationScheme,
     initialize_module_for_quantization,
+    QuantizationStatus,
 )
 
-from llmcompressor.modifiers.quantization.calibration import initialize_observer
+from llmcompressor.modifiers.quantization.calibration import (
+    initialize_observer,
+    update_weight_zp_scale,
+    update_weight_global_scale,
+    calibrate_input_hook,
+)
 
 
 @pytest.mark.parametrize(
@@ -59,3 +65,273 @@ def test_observers_update(shape, group_size, actorder):
 def assert_alike(a, b):
     assert a.dtype == b.dtype
     assert a.shape == b.shape
+
+
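+# The expected min/max values below assume calibration data holding the values
+# 0..23 in a 4x6 matrix (built in each test body): "tensor" observes the full
+# range, "channel"/"token" observe one row each ([0..5], [6..11], [12..17],
+# [18..23]), "group"/"tensor_group" with group_size=3 split each row into two
+# halves, and "block" with block_structure=[2, 3] tiles the matrix into four
+# 2x3 blocks.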
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_tol",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+            {"default": torch.tensor(23.0, dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+                observer="minmax",
+            ),
+            {"default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16)},
+            {"default": torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",
+                symmetric=True,
+                strategy="tensor_group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            5.0,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="block",
+                block_structure=[2, 3],
+                observer="minmax",
+            ),
+            {
+                "block_0_0": torch.tensor([[0]], dtype=torch.bfloat16),
+                "block_0_1": torch.tensor([[3]], dtype=torch.bfloat16),
+                "block_1_0": torch.tensor([[12]], dtype=torch.bfloat16),
+                "block_1_1": torch.tensor([[15]], dtype=torch.bfloat16),
+            },
+            {
+                "block_0_0": torch.tensor([[8]], dtype=torch.bfloat16),
+                "block_0_1": torch.tensor([[11]], dtype=torch.bfloat16),
+                "block_1_0": torch.tensor([[20]], dtype=torch.bfloat16),
+                "block_1_1": torch.tensor([[23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+    ],
+)
+def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
+    # set up weight
+    input_size, output_size = 6, 4
+    linear = torch.nn.Linear(input_size, output_size, bias=False)
+    linear.weight.data = torch.arange(
+        input_size * output_size, dtype=torch.bfloat16
+    ).reshape(output_size, input_size)
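+    # weight rows are [0..5], [6..11], [12..17], [18..23], matching the
+    # expected observer ranges in the parametrize cases above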
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
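+    # initialize_observer attaches a minmax observer as `weight_observer`;
+    # update_weight_zp_scale then folds the observed ranges into the module's
+    # scale and zero-point parameters (update_weight_global_scale only takes
+    # effect for tensor_group schemes, which also calibrate a global scale)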
+    initialize_observer(linear, "weight")
+    update_weight_global_scale(linear)
+    update_weight_zp_scale(linear)
+
+    observer = getattr(linear, "weight_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # forward pass
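+    # atol is a loose empirical bound on the weight fake-quantization error:
+    # with 4 bits the quantization step is roughly max_abs / 8 (about 3 for
+    # the per-tensor case), so each weight moves by at most half a step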
+    input = torch.rand((1, input_size), dtype=torch.bfloat16)
+    output = linear(input)
+    true_output = input @ linear.weight.T
+    assert torch.allclose(output, true_output, atol=exp_tol)
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_tol",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+            {"default": torch.tensor(23.0, dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="token",
+                observer="minmax",
+            ),
+            {"default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16)},
+            {"default": torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16)},
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",
+                symmetric=True,
+                strategy="tensor_group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
+                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
+                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+            },
+            2.5,
+        ),
+        # (
+        #     QuantizationArgs(
+        #         num_bits=4,
+        #         type="int",
+        #         symmetric=True,
+        #         strategy="block",
+        #         block_structure=[2, 3],
+        #         observer="minmax",
+        #     ),
+        #     {
+        #         "block_0_0": torch.tensor([[0]], dtype=torch.bfloat16),
+        #         "block_0_1": torch.tensor([[3]], dtype=torch.bfloat16),
+        #         "block_1_0": torch.tensor([[12]], dtype=torch.bfloat16),
+        #         "block_1_1": torch.tensor([[15]], dtype=torch.bfloat16),
+        #     },
+        #     {
+        #         "block_0_0": torch.tensor([[8]], dtype=torch.bfloat16),
+        #         "block_0_1": torch.tensor([[11]], dtype=torch.bfloat16),
+        #         "block_1_0": torch.tensor([[20]], dtype=torch.bfloat16),
+        #         "block_1_1": torch.tensor([[23]], dtype=torch.bfloat16),
+        #     },
+        #     2.5,
+        # ),
+    ],
+)
+def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
+    # set up activation (and identity weight); four "tokens" of six features
+    # cover the values 0..23 assumed by the expected observer ranges above
+    input_size = 6
+    input = torch.arange(4 * input_size, dtype=torch.bfloat16).reshape(4, input_size)
+    linear = torch.nn.Linear(input_size, input_size, bias=False)
+    linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)
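+    # with an identity weight, any deviation of the output from the input is
+    # due solely to activation fake quantization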
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
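+    # calibrate_input_hook runs as a forward pre-hook, routing each incoming
+    # activation through the observer before the linear op executes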
+    initialize_observer(linear, "input")
+    linear.register_forward_pre_hook(calibrate_input_hook)
+    linear(input)  # trigger the hook so the observer records min/max values
+
+    observer = getattr(linear, "input_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # forward pass
+    linear.quantization_status = QuantizationStatus.FROZEN
+    output = linear(input)
+    true_output = input  # identity weight: input @ linear.weight.T == input
+    assert torch.allclose(output, true_output, atol=exp_tol)