Commit bb9391a

Author: Bluedyson
Commit message: Add awq activation fp8 support in loss compute
Signed-off-by: Bluedyson <[email protected]>
Parent: ef26dc4

File tree

2 files changed: +62 -6 lines

src/llmcompressor/modifiers/awq/base.py
tests/llmcompressor/modifiers/awq/test_base.py

src/llmcompressor/modifiers/awq/base.py

Lines changed: 28 additions & 5 deletions

@@ -3,7 +3,7 @@
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import QuantizationType, disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -126,6 +126,7 @@ class AWQModifier(Modifier, QuantizationMixin):

     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
+    _activation_bits: int = PrivateAttr(default=16)
     _symmetric: Optional[bool] = PrivateAttr(default=None)
     _group_size: Optional[int] = PrivateAttr(default=None)

@@ -189,6 +190,18 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
             if act is not None
         }
         if not (len(num_bits_set) == 0 or num_bits_set == {16}):
+            num_bits_type = {
+                act.type
+                for group in config.config_groups.values()
+                for act in (group.input_activations, group.output_activations)
+                if act is not None
+            }
+            assert (
+                next(iter(num_bits_type)) == QuantizationType.FLOAT
+            ), "In AWQ, lower-precision activation quantization must be float"
+
+            model._activation_bits = next(iter(num_bits_set))
+
             warnings.warn(
                 "A strategy including activation quantization was detected. "
                 "AWQ was originally intended for weight-only quantization. "
@@ -612,16 +625,26 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
+                scaled_weight = (
                     _pseudo_quantize_tensor(
                         w=linear.weight.data,
                         symmetric=self._symmetric,
                         bit_width=self._num_bits,
                         group_size=self._group_size,
                     )[0]
-                    / _scalesview,
+                    / _scalesview
+                )
+
+                # fp8 activation simulation
+                if self._activation_bits == 8:
+                    scaled_weight = scaled_weight.to(torch.float8_e4m3fn).to(
+                        torch.float16
+                    )
+
+                update_offload_parameter(
+                    linear,
+                    "weight",
+                    scaled_weight,
                 )

             # W * X
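
For reference, the fp8 branch added above is a round-trip cast: the scaled, pseudo-quantized weight is converted to torch.float8_e4m3fn and back to float16, so the rounding error of an fp8 representation is baked into the tensor used in the subsequent W * X loss comparison. A minimal standalone sketch of that round trip (tensor name and shape are illustrative, not from the commit; requires a PyTorch build with float8 dtypes):

import torch

# Illustrative stand-in for the AWQ-scaled, pseudo-quantized weight.
scaled_weight = torch.randn(128, 256, dtype=torch.float16)

# Round-trip through fp8 (e4m3) and back to fp16; the cast introduces
# the rounding error an fp8 kernel would see at runtime.
fp8_simulated = scaled_weight.to(torch.float8_e4m3fn).to(torch.float16)

# Worst-case simulated fp8 error for this tensor.
print((fp8_simulated - scaled_weight).abs().max())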

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 34 additions & 1 deletion

@@ -1,6 +1,10 @@
 import pytest
 import torch
-from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationType,
+)
 from pydantic import ValidationError

 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
@@ -154,6 +158,25 @@ def test_validate():
             }
         )

+    with pytest.raises(ValidationError):
+        AWQModifier(
+            config_groups={
+                "group_0": QuantizationScheme(
+                    targets=["Linear"],
+                    weights=QuantizationArgs(
+                        num_bits=4,
+                        group_size=128,
+                    ),
+                    input_activations=QuantizationArgs(
+                        num_bits=8, type=QuantizationType.INT
+                    ),
+                    output_activations=QuantizationArgs(
+                        num_bits=8, type=QuantizationType.INT
+                    ),
+                ),
+            }
+        )
+
     # valid configuration
     AWQModifier(
         config_groups={
@@ -165,6 +188,16 @@ def test_validate():
                 targets=["Linear"],
                 weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
             ),
+            "group_2": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
+                input_activations=QuantizationArgs(
+                    num_bits=8, type=QuantizationType.FLOAT
+                ),
+                output_activations=QuantizationArgs(
+                    num_bits=8, type=QuantizationType.FLOAT
+                ),
+            ),
         }
     )
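
Taken together, the new test cases pin down the validator behaviour: 8-bit integer activations are rejected, while 8-bit float activations alongside 4-bit weights pass validation. The sketch below mirrors those two cases outside the test harness; it only reuses constructs that already appear in this diff, and the group names are illustrative.

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationType,
)
from pydantic import ValidationError

from llmcompressor.modifiers.awq import AWQModifier

# W4 weights with float 8-bit activations: mirrors the new valid test case.
modifier = AWQModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
            input_activations=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT
            ),
            output_activations=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT
            ),
        ),
    }
)

# The same shape of config with integer 8-bit activations is expected to
# raise a ValidationError, mirroring the new pytest.raises case above.
try:
    AWQModifier(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=4, group_size=128),
                input_activations=QuantizationArgs(
                    num_bits=8, type=QuantizationType.INT
                ),
                output_activations=QuantizationArgs(
                    num_bits=8, type=QuantizationType.INT
                ),
            ),
        }
    )
except ValidationError as err:
    print("rejected:", err)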
