@@ -157,7 +157,7 @@ def assert_alike(a, b):
         ),
     ],
 )
-def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
+def test_static_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
     # set up weight
     input_size, output_size = 6, 4
     linear = torch.nn.Linear(input_size, output_size, bias=False)
@@ -205,7 +205,7 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
                 observer="minmax",
             ),
             {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
-            {"default": torch.tensor(23.0, dtype=torch.bfloat16)},
+            {"default": torch.tensor(5.0, dtype=torch.bfloat16)},
             2.5,
         ),
         (
@@ -216,28 +216,23 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
                 strategy="token",
                 observer="minmax",
             ),
-            {"default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16)},
-            {"default": torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16)},
-            2.5,
-        ),
-        (
-            QuantizationArgs(
-                num_bits=4,
-                type="int",
-                symmetric=True,
-                strategy="channel",
-                observer="minmax",
-            ),
-            {
-                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
-                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
-            },
-            {
-                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
-                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
-            },
+            {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+            {"default": torch.tensor(5.0, dtype=torch.bfloat16)},
             2.5,
         ),
+        # channel is not supported, but (tensor == token == channel)
+        # (
+        #     QuantizationArgs(
+        #         num_bits=4,
+        #         type="int",
+        #         symmetric=True,
+        #         strategy="channel",
+        #         observer="minmax",
+        #     ),
+        #     {"default": torch.tensor(0.0, dtype=torch.bfloat16)},
+        #     {"default": torch.tensor(5.0, dtype=torch.bfloat16)},
+        #     2.5,
+        # ),
         (
             QuantizationArgs(
                 num_bits=4,
@@ -248,12 +243,12 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
                 observer="minmax",
             ),
             {
-                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
-                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+                "default": torch.tensor([[0]], dtype=torch.bfloat16),
+                1: torch.tensor([[3]], dtype=torch.bfloat16),
             },
             {
-                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
-                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+                "default": torch.tensor([[2]], dtype=torch.bfloat16),
+                1: torch.tensor([[5]], dtype=torch.bfloat16),
             },
             2.5,
         ),
@@ -267,12 +262,12 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
                 observer="minmax",
             ),
             {
-                "default": torch.tensor([[0], [6], [12], [18]], dtype=torch.bfloat16),
-                1: torch.tensor([[3], [9], [15], [21]], dtype=torch.bfloat16),
+                "default": torch.tensor([[0]], dtype=torch.bfloat16),
+                1: torch.tensor([[3]], dtype=torch.bfloat16),
             },
             {
-                "default": torch.tensor([[2], [8], [14], [20]], dtype=torch.bfloat16),
-                1: torch.tensor([[5], [11], [17], [23]], dtype=torch.bfloat16),
+                "default": torch.tensor([[2]], dtype=torch.bfloat16),
+                1: torch.tensor([[5]], dtype=torch.bfloat16),
             },
             2.5,
         ),
@@ -301,7 +296,7 @@ def test_weight_quantization(args, exp_min_val, exp_max_val, exp_tol):
         # ),
     ],
 )
-def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
+def test_static_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
     # set up activation (and identity weight)
     input_size = 6
     input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0)
@@ -317,7 +312,11 @@ def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
     initialize_observer(linear, "input")
     linear.register_forward_pre_hook(calibrate_input_hook)

+    # calibration forward pass
+    output = linear(input)
+    assert torch.allclose(output, input, atol=exp_tol)

+    # check calibration
     observer = getattr(linear, "input_observer")
     breakpoint()
     assert (
@@ -328,10 +327,4 @@ def test_activation_quantization(args, exp_min_val, exp_max_val, exp_tol):
     )
     for key in observer.min_val.keys():
         assert torch.equal(observer.min_val[key], exp_min_val[key])
-        assert torch.equal(observer.max_val[key], exp_max_val[key])
-
-    # forward pass
-    linear.quantization_status = QuantizationStatus.FROZEN
-    output = linear(input)
-    true_output = input  # (@ linear.weight.T == eye)
-    assert torch.allclose(output, true_output, atol=exp_tol)
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
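Taken together, the hunks above reduce the activation test to one pattern: attach a min/max observer to the layer input, run a single calibration forward pass through an identity Linear, then assert on the recorded statistics. Below is a minimal, self-contained sketch of that pattern, not the library's API: MinMaxObserver and record_input_minmax are hypothetical stand-ins for the initialize_observer and calibrate_input_hook helpers the test uses, and the expected min/max of 0.0 and 5.0 follow from the torch.arange(6) input used in the test.

import torch


class MinMaxObserver:
    """Tracks the running min/max of every tensor it observes."""

    def __init__(self):
        self.min_val = {"default": None}
        self.max_val = {"default": None}

    def update(self, x):
        lo, hi = x.min(), x.max()
        cur_lo, cur_hi = self.min_val["default"], self.max_val["default"]
        self.min_val["default"] = lo if cur_lo is None else torch.minimum(cur_lo, lo)
        self.max_val["default"] = hi if cur_hi is None else torch.maximum(cur_hi, hi)


def record_input_minmax(module, args):
    # forward pre-hook: observe the first positional input before the layer runs
    module.input_observer.update(args[0])


input_size = 6
linear = torch.nn.Linear(input_size, input_size, bias=False)
linear.weight.data = torch.eye(input_size)  # identity weight, so output == input
linear.input_observer = MinMaxObserver()
linear.register_forward_pre_hook(record_input_minmax)

# calibration forward pass over the input [[0, 1, 2, 3, 4, 5]]
x = torch.arange(input_size, dtype=torch.float32).unsqueeze(0)
out = linear(x)

assert torch.allclose(out, x)  # identity weight passes the input through
assert linear.input_observer.min_val["default"] == 0.0
assert linear.input_observer.max_val["default"] == 5.0

Because the calibration input is a single row observed per tensor, the same statistics come out whichever strategy is parametrized, which is what the "# channel is not supported, but (tensor == token == channel)" note in the parametrization refers to.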