|
6 | 6 | initialize_module_for_quantization,
|
7 | 7 | )
|
8 | 8 |
|
9 |
| -from llmcompressor.modifiers.quantization.calibration import initialize_observer |
| 9 | +from llmcompressor.modifiers.quantization.calibration import ( |
| 10 | + calibrate_input_hook, |
| 11 | + initialize_observer, |
| 12 | + update_weight_global_scale, |
| 13 | + update_weight_zp_scale, |
| 14 | +) |
10 | 15 |
|
11 | 16 |
|
12 | 17 | @pytest.mark.parametrize(
|
@@ -59,3 +64,320 @@ def test_observers_update(shape, group_size, actorder):
|
59 | 64 | def assert_alike(a, b):
|
60 | 65 | assert a.dtype == b.dtype
|
61 | 66 | assert a.shape == b.shape
|
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.parametrize( |
| 70 | + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", |
| 71 | + [ |
| 72 | + ( |
| 73 | + QuantizationArgs( |
| 74 | + num_bits=4, |
| 75 | + type="int", |
| 76 | + symmetric=True, |
| 77 | + strategy="tensor", # equivalent to token |
| 78 | + observer="minmax", |
| 79 | + ), |
| 80 | + {"default": torch.tensor(0.0)}, |
| 81 | + {"default": torch.tensor(23.0)}, |
| 82 | + torch.tensor( |
| 83 | + [ |
| 84 | + [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250], |
| 85 | + [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500], |
| 86 | + [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750], |
| 87 | + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], |
| 88 | + ], |
| 89 | + dtype=torch.bfloat16, |
| 90 | + ), |
| 91 | + 0.85, |
| 92 | + ), |
| 93 | + ( |
| 94 | + QuantizationArgs( |
| 95 | + num_bits=4, |
| 96 | + type="int", |
| 97 | + symmetric=True, |
| 98 | + strategy="channel", |
| 99 | + observer="minmax", |
| 100 | + ), |
| 101 | + {"default": torch.tensor([[0], [6], [12], [18]])}, |
| 102 | + {"default": torch.tensor([[5], [11], [17], [23]])}, |
| 103 | + torch.tensor( |
| 104 | + [ |
| 105 | + [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875], |
| 106 | + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], |
| 107 | + [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750], |
| 108 | + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], |
| 109 | + ], |
| 110 | + dtype=torch.bfloat16, |
| 111 | + ), |
| 112 | + 0.45, |
| 113 | + ), |
| 114 | + ( |
| 115 | + QuantizationArgs( |
| 116 | + num_bits=4, |
| 117 | + type="int", |
| 118 | + symmetric=True, |
| 119 | + strategy="group", |
| 120 | + group_size=3, |
| 121 | + observer="minmax", |
| 122 | + ), |
| 123 | + { |
| 124 | + "default": torch.tensor([[0], [6], [12], [18]]), |
| 125 | + 1: torch.tensor([[3], [9], [15], [21]]), |
| 126 | + }, |
| 127 | + { |
| 128 | + "default": torch.tensor([[2], [8], [14], [20]]), |
| 129 | + 1: torch.tensor([[5], [11], [17], [23]]), |
| 130 | + }, |
| 131 | + torch.tensor( |
| 132 | + [ |
| 133 | + [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875], |
| 134 | + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], |
| 135 | + [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750], |
| 136 | + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], |
| 137 | + ], |
| 138 | + ), |
| 139 | + 0.45, |
| 140 | + ), |
| 141 | + ( |
| 142 | + QuantizationArgs( |
| 143 | + num_bits=4, |
| 144 | + type="float", # tensor group requires FP4 |
| 145 | + symmetric=True, |
| 146 | + strategy="tensor_group", # requires float4 |
| 147 | + group_size=3, |
| 148 | + observer="minmax", |
| 149 | + ), |
| 150 | + { |
| 151 | + "default": torch.tensor([[0], [6], [12], [18]]), |
| 152 | + 1: torch.tensor([[3], [9], [15], [21]]), |
| 153 | + }, |
| 154 | + { |
| 155 | + "default": torch.tensor([[2], [8], [14], [20]]), |
| 156 | + 1: torch.tensor([[5], [11], [17], [23]]), |
| 157 | + }, |
| 158 | + torch.tensor( |
| 159 | + [ |
| 160 | + [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375], |
| 161 | + [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875], |
| 162 | + [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750], |
| 163 | + [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000], |
| 164 | + ], |
| 165 | + ), |
| 166 | + 1.1, |
| 167 | + ), |
| 168 | + ( |
| 169 | + QuantizationArgs( |
| 170 | + num_bits=4, |
| 171 | + type="int", |
| 172 | + symmetric=True, |
| 173 | + strategy="block", |
| 174 | + block_structure=[2, 3], |
| 175 | + observer="minmax", |
| 176 | + ), |
| 177 | + { |
| 178 | + "block_0_0": torch.tensor([[0]]), |
| 179 | + "block_0_1": torch.tensor([[3]]), |
| 180 | + "block_1_0": torch.tensor([[12]]), |
| 181 | + "block_1_1": torch.tensor([[15]]), |
| 182 | + }, |
| 183 | + { |
| 184 | + "block_0_0": torch.tensor([[8]]), |
| 185 | + "block_0_1": torch.tensor([[11]]), |
| 186 | + "block_1_0": torch.tensor([[20]]), |
| 187 | + "block_1_1": torch.tensor([[23]]), |
| 188 | + }, |
| 189 | + torch.tensor( |
| 190 | + [ |
| 191 | + [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062], |
| 192 | + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], |
| 193 | + [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750], |
| 194 | + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], |
| 195 | + ], |
| 196 | + ), |
| 197 | + 0.5, |
| 198 | + ), |
| 199 | + ( |
| 200 | + QuantizationArgs( |
| 201 | + num_bits=4, |
| 202 | + type="int", |
| 203 | + symmetric=True, |
| 204 | + strategy="token", # equivalent to tensor |
| 205 | + observer="minmax", |
| 206 | + ), |
| 207 | + {"default": torch.tensor(0.0)}, |
| 208 | + {"default": torch.tensor(23.0)}, |
| 209 | + torch.tensor( |
| 210 | + [ |
| 211 | + [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250], |
| 212 | + [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500], |
| 213 | + [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750], |
| 214 | + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], |
| 215 | + ], |
| 216 | + dtype=torch.bfloat16, |
| 217 | + ), |
| 218 | + 0.85, |
| 219 | + ), |
| 220 | + ], |
| 221 | +) |
| 222 | +def test_static_weight_quantization( |
| 223 | + args, exp_min_val, exp_max_val, exp_quant, exp_loss |
| 224 | +): |
| 225 | + """ |
| 226 | + weight = tensor([[ 0, 1, 2, 3, 4, 5], |
| 227 | + [ 6, 7, 8, 9, 10, 11], |
| 228 | + [12, 13, 14, 15, 16, 17], |
| 229 | + [18, 19, 20, 21, 22, 23]]) |
| 230 | + """ |
| 231 | + # set up weight |
| 232 | + input_size, output_size = 6, 4 |
| 233 | + linear = torch.nn.Linear(input_size, output_size, bias=False) |
| 234 | + linear.weight.data = torch.arange( |
| 235 | + input_size * output_size, dtype=torch.bfloat16 |
| 236 | + ).reshape(output_size, input_size) |
| 237 | + |
| 238 | + # initialize quantization parameters |
| 239 | + scheme = QuantizationScheme(targets=[], weights=args) |
| 240 | + initialize_module_for_quantization(linear, scheme) |
| 241 | + assert getattr(linear, "quantization_scheme") is scheme |
| 242 | + |
| 243 | + # calibrate quantization parameters |
| 244 | + initialize_observer(linear, "weight") |
| 245 | + update_weight_global_scale(linear) |
| 246 | + update_weight_zp_scale(linear) |
| 247 | + |
| 248 | + observer = getattr(linear, "weight_observer") |
| 249 | + assert ( |
| 250 | + observer.min_val.keys() |
| 251 | + == observer.max_val.keys() |
| 252 | + == exp_min_val.keys() |
| 253 | + == exp_max_val.keys() |
| 254 | + ) |
| 255 | + for key in observer.min_val.keys(): |
| 256 | + assert torch.equal(observer.min_val[key], exp_min_val[key]) |
| 257 | + assert torch.equal(observer.max_val[key], exp_max_val[key]) |
| 258 | + |
| 259 | + # forward pass |
| 260 | + input = torch.eye(input_size, dtype=torch.bfloat16) |
| 261 | + output = linear(input) |
| 262 | + |
| 263 | + print(output.T) |
| 264 | + print(torch.nn.functional.mse_loss(output.T, linear.weight)) |
| 265 | + assert torch.allclose(output.T, exp_quant.to(output.dtype)) |
| 266 | + assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss |
| 267 | + |
| 268 | + |
| 269 | +@pytest.mark.parametrize( |
| 270 | + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", |
| 271 | + [ |
| 272 | + ( |
| 273 | + QuantizationArgs( |
| 274 | + num_bits=4, |
| 275 | + type="int", |
| 276 | + symmetric=True, |
| 277 | + strategy="tensor", # equivalent to token |
| 278 | + observer="minmax", |
| 279 | + ), |
| 280 | + {"default": torch.tensor(0.0)}, |
| 281 | + {"default": torch.tensor(5.0)}, |
| 282 | + torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]), |
| 283 | + 0.06, |
| 284 | + ), |
| 285 | + ( |
| 286 | + QuantizationArgs( |
| 287 | + num_bits=4, |
| 288 | + type="int", |
| 289 | + symmetric=True, |
| 290 | + strategy="token", # equivalent to tensor |
| 291 | + observer="minmax", |
| 292 | + ), |
| 293 | + {"default": torch.tensor(0.0)}, |
| 294 | + {"default": torch.tensor(5.0)}, |
| 295 | + torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]), |
| 296 | + 0.06, |
| 297 | + ), |
| 298 | + # channel is not supported, but is in principle equivalent to token/tensor |
| 299 | + # ( |
| 300 | + # QuantizationArgs( |
| 301 | + # num_bits=4, |
| 302 | + # type="int", |
| 303 | + # symmetric=True, |
| 304 | + # strategy="group", |
| 305 | + # group_size=3, |
| 306 | + # observer="minmax", |
| 307 | + # ), |
| 308 | + # { |
| 309 | + # "default": torch.tensor([[0]]), |
| 310 | + # 1: torch.tensor([[3]]), |
| 311 | + # }, |
| 312 | + # { |
| 313 | + # "default": torch.tensor([[2]]), |
| 314 | + # 1: torch.tensor([[5]]), |
| 315 | + # }, |
| 316 | + # torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]), |
| 317 | + # 0.04, |
| 318 | + # ), |
| 319 | + # ( |
| 320 | + # QuantizationArgs( |
| 321 | + # num_bits=4, |
| 322 | + # type="float", # tensor group requires FP4 |
| 323 | + # symmetric=True, |
| 324 | + # strategy="tensor_group", |
| 325 | + # group_size=3, |
| 326 | + # observer="minmax", |
| 327 | + # ), |
| 328 | + # { |
| 329 | + # "default": torch.tensor([[0]]), |
| 330 | + # 1: torch.tensor([[3]]), |
| 331 | + # }, |
| 332 | + # { |
| 333 | + # "default": torch.tensor([[2]]), |
| 334 | + # 1: torch.tensor([[5]]), |
| 335 | + # }, |
| 336 | + # torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]), |
| 337 | + # 0.1, |
| 338 | + # ), |
| 339 | + # block is not supported, but is in principle similar to group |
| 340 | + ], |
| 341 | +) |
| 342 | +def test_static_activation_quantization( |
| 343 | + args, exp_min_val, exp_max_val, exp_quant, exp_loss |
| 344 | +): |
| 345 | + """ |
| 346 | + input = tensor([[ 0, 1, 2, 3, 4, 5]]) |
| 347 | + """ |
| 348 | + # set up activation (and identity weight) |
| 349 | + input_size = 6 |
| 350 | + input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0) |
| 351 | + linear = torch.nn.Linear(input_size, input_size, bias=False) |
| 352 | + linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16) |
| 353 | + |
| 354 | + # initialize quantization parameters |
| 355 | + scheme = QuantizationScheme(targets=[], input_activations=args) |
| 356 | + initialize_module_for_quantization(linear, scheme) |
| 357 | + assert getattr(linear, "quantization_scheme") is scheme |
| 358 | + |
| 359 | + # calibrate quantization parameters |
| 360 | + initialize_observer(linear, "input") |
| 361 | + linear.register_forward_pre_hook(calibrate_input_hook) |
| 362 | + |
| 363 | + # calibration forward pass |
| 364 | + output = linear(input) |
| 365 | + |
| 366 | + # check calibration |
| 367 | + observer = getattr(linear, "input_observer") |
| 368 | + assert ( |
| 369 | + observer.min_val.keys() |
| 370 | + == observer.max_val.keys() |
| 371 | + == exp_min_val.keys() |
| 372 | + == exp_max_val.keys() |
| 373 | + ) |
| 374 | + for key in observer.min_val.keys(): |
| 375 | + assert torch.equal(observer.min_val[key], exp_min_val[key]) |
| 376 | + assert torch.equal(observer.max_val[key], exp_max_val[key]) |
| 377 | + |
| 378 | + # check forward pass |
| 379 | + print(args.strategy) |
| 380 | + print(output) |
| 381 | + print(torch.nn.functional.mse_loss(output, input)) |
| 382 | + assert torch.allclose(output, exp_quant.to(output.dtype)) |
| 383 | + assert torch.nn.functional.mse_loss(output, input) <= exp_loss |
0 commit comments