 from typing import Dict, List, Optional, Tuple, Union

 import torch

-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import QuantizationType, disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -126,6 +126,7 @@ class AWQModifier(Modifier, QuantizationMixin):

     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
+    _activation_bits: int = PrivateAttr(default=16)
     _symmetric: Optional[bool] = PrivateAttr(default=None)
     _group_size: Optional[int] = PrivateAttr(default=None)

@@ -189,6 +190,18 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
             if act is not None
         }
         if not (len(num_bits_set) == 0 or num_bits_set == {16}):
+            num_bits_type = {
+                act.type
+                for group in config.config_groups.values()
+                for act in (group.input_activations, group.output_activations)
+                if act is not None
+            }
+            assert (
+                next(iter(num_bits_type)) == QuantizationType.FLOAT
+            ), "In AWQ, lower-precision activation quantization must be float"
+
+            model._activation_bits = next(iter(num_bits_set))
+
             warnings.warn(
                 "A strategy including activation quantization was detected. "
                 "AWQ was originally intended for weight-only quantization. "
@@ -612,16 +625,26 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
+                scaled_weight = (
                     _pseudo_quantize_tensor(
                         w=linear.weight.data,
                         symmetric=self._symmetric,
                         bit_width=self._num_bits,
                         group_size=self._group_size,
                     )[0]
-                    / _scalesview,
+                    / _scalesview
+                )
+
+                # fp8 activation simulation
+                if self._activation_bits == 8:
+                    scaled_weight = scaled_weight.to(torch.float8_e4m3fn).to(
+                        torch.float16
+                    )
+
+                update_offload_parameter(
+                    linear,
+                    "weight",
+                    scaled_weight,
                 )

             # W * X
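As a rough standalone illustration (plain PyTorch, not the project's helpers) of the two steps in the last hunk: a group-wise symmetric quantize-dequantize in the spirit of `_pseudo_quantize_tensor`, followed by the `float8_e4m3fn` round-trip used to simulate fp8 activation error. `pseudo_quantize_int4` is a hypothetical stand-in, and the round-trip assumes a PyTorch build with float8 dtypes (>= 2.1).

import torch


def pseudo_quantize_int4(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    # symmetric group-wise quantize-dequantize (illustrative, not the repo's helper)
    orig_shape = w.shape
    w = w.reshape(-1, group_size)
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 7  # int4: [-8, 7]
    w_q = torch.clamp(torch.round(w / scales), -8, 7)
    return (w_q * scales).reshape(orig_shape)


w = torch.randn(256, 512, dtype=torch.float16)
w_dq = pseudo_quantize_int4(w)

# fp8 activation simulation, as in the hunk above: cast to float8_e4m3fn and back
# so the scale search sees the rounding error an fp8 kernel would introduce
w_dq = w_dq.to(torch.float8_e4m3fn).to(torch.float16)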
|