
Commit bf2ddc6

[bugfix] Fix static asymmetric quantization case (#10334)
Signed-off-by: Daniël de Kok <[email protected]>
Signed-off-by: luka <[email protected]>
Co-authored-by: Daniël de Kok <[email protected]>
1 parent 972112d commit bf2ddc6

5 files changed: 58 additions & 15 deletions

tests/kernels/test_int8_quant.py

Lines changed: 10 additions & 9 deletions
@@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     assert torch_out.min() >= int8_traits.min and torch_out.max(
     ) <= int8_traits.max
 
-    ops_out = torch.empty_like(x, dtype=torch.int8)
-    scales_out = torch.empty_like(scales, dtype=torch.float32)
-    azp_out = torch.empty_like(azps, dtype=torch.int32)
-    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+    ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
 
     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
@@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
 
     out1 = (x / scale_arg).round().clamp(int8_traits.min,
                                          int8_traits.max).to(torch.int8)
-    out2, _, _ = scaled_int8_quant(x, scale_arg)
+    out2, scale2, _ = scaled_int8_quant(x, scale_arg)
+    assert scale2 is scale_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 
     out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                              int8_traits.max).to(torch.int8)
-    out2 = torch.empty_like(x, dtype=torch.int8)
     scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
     azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
 
-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+    out2, scale2, azp2 = scaled_int8_quant(x,
+                                           scale_arg,
+                                           azp_arg,
+                                           symmetric=False)
+    assert scale2 is scale_arg
+    assert azp2 is azp_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
     val_i8 = int8_traits.max if is_max else int8_traits.min
     expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
 
-    out = torch.empty_like(expected)
-    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
     torch.testing.assert_close(expected, out, atol=0, rtol=0)
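Taken together, these test changes exercise the public scaled_int8_quant wrapper instead of the raw torch.ops._C kernels, and they rely on the wrapper now returning the caller-supplied scale and zero point. A minimal usage sketch of the wrapper as changed in this commit (assumes a CUDA device and a vLLM build with the custom ops compiled; shapes and values are illustrative only):

import torch

from vllm._custom_ops import scaled_int8_quant

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")

# Dynamic asymmetric quantization: per-token scales and zero points are computed.
out_dyn, scales_dyn, azp_dyn = scaled_int8_quant(x, symmetric=False)

# Static asymmetric quantization: scale and zero point come from the caller and,
# after this fix, are handed back unchanged as the second and third results.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
azp = torch.tensor([3], dtype=torch.int32, device="cuda")
out_static, scale_out, azp_out = scaled_int8_quant(x, scale, azp, symmetric=False)
assert scale_out is scale and azp_out is azp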

tests/quantization/test_compressed_tensors.py

Lines changed: 30 additions & 0 deletions
@@ -8,6 +8,7 @@
 import torch
 from compressed_tensors.quantization import QuantizationType
 
+from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
@@ -74,6 +75,35 @@ def zp_valid(zp: Optional[torch.Tensor]):
     assert output
 
 
+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "neuralmagic/Llama-3.2-1B-quantized.w8a8"
+        # TODO static & asymmetric
+    ])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
+                                          example_prompts, model_path,
+                                          max_tokens, num_logprobs):
+    dtype = "bfloat16"
+
+    with hf_runner(model_path, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(model_path, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path) as llm:

vllm/_custom_ops.py

Lines changed: 7 additions & 1 deletion
@@ -510,10 +510,16 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
                           azp_adj: torch.Tensor,
                           azp: Optional[torch.Tensor] = None,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    :param azp_adj: In the per-tensor case, this should include the azp.
+        Always per-channel.
+    :param azp: Only set in the per-token case. Per-token if set.
+    """
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]
 
     m = a.shape[0]
     n = b.shape[1]
@@ -735,7 +741,7 @@ def scaled_int8_quant(
             azp is
             None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, None
+        return output, scale, azp
 
     # dynamic-per-token quantization.
     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
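The new docstring pins down the azp contract for the CUTLASS epilogue: azp_adj always holds the per-output-channel weight column sums, with the zero point already folded in for the per-tensor case, while a separate azp tensor is only passed for per-token zero points. As a rough illustration of what the kernel computes, here is a plain-PyTorch reference under those assumptions (not the actual CUTLASS implementation):

import torch


def ref_cutlass_scaled_mm_azp(a_q, b_q, scale_a, scale_b, out_dtype,
                              azp_adj, azp=None, bias=None):
    # Zero-point-aware int8 GEMM: dequantized result of (a_q - azp) @ b_q.
    acc = a_q.to(torch.int32) @ b_q.to(torch.int32)
    if azp is None:
        # Per-tensor zero point: already multiplied into azp_adj by the caller.
        acc = acc - azp_adj
    else:
        # Per-token zero points: azp_adj holds the raw column sums of b_q.
        acc = acc - azp.view(-1, 1) * azp_adj
    out = acc.to(torch.float32) * scale_a * scale_b
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)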

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 7 additions & 4 deletions
@@ -82,9 +82,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
         # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
         if not self.input_symmetric:
-            layer.azp_adj = layer.weight.sum(dim=0,
-                                             keepdim=True,
-                                             dtype=torch.int32)
+            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = layer.input_zero_point * azp_adj
+
+            layer.azp_adj = azp_adj
         else:
             layer.azp_adj = None
 
@@ -138,7 +142,6 @@ def create_weights(self, layer: torch.nn.Module,
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
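The folding step is the heart of the fix on the model side: with a static (per-tensor) input zero point, the kernel no longer receives an azp tensor, so the correction term azp * sum_k W[k, :] has to be baked into azp_adj up front. A small self-contained check of that identity (illustrative values, not vLLM code):

import torch

torch.manual_seed(0)
m, k, n = 4, 8, 3
x_q = torch.randint(-128, 128, (m, k), dtype=torch.int32)
w_q = torch.randint(-128, 128, (k, n), dtype=torch.int32)
azp = torch.tensor(11, dtype=torch.int32)  # per-tensor input zero point

# Subtracting the zero point before the matmul ...
direct = (x_q - azp) @ w_q

# ... equals subtracting azp times the weight's column sums afterwards,
# which is exactly the folded azp_adj precomputed in the scheme above.
azp_adj = azp * w_q.sum(dim=0, keepdim=True)
folded = x_q @ w_q - azp_adj

assert torch.equal(direct, folded)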

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 4 additions & 1 deletion
@@ -211,13 +211,16 @@ def apply_int8_linear(
                                                  symmetric=symmetric)
 
     if x_zp is not None:
+        # Currently, static is always per-tensor and dynamic is per-token
+        static = input_zero_point is not None
+        azp = None if static else x_zp
         return ops.cutlass_scaled_mm_azp(x_q,
                                          weight,
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          out_dtype=input.dtype,
                                          azp_adj=azp_adj,
-                                         azp=x_zp,
+                                         azp=azp,
                                          bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
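With this change, apply_int8_linear distinguishes the two asymmetric paths by whether a static input_zero_point exists: static activations are per-tensor, so only the pre-folded azp_adj reaches the kernel, while dynamic activations are per-token and additionally pass the freshly computed zero points. A compact summary of that dispatch (an illustrative sketch, not the vLLM function itself):

# Which azp arguments reach the CUTLASS kernels after this commit, assuming
# static == per-tensor and dynamic == per-token activation quantization.
def azp_args_for_kernel(symmetric, input_zero_point, x_zp, weight_col_sums):
    if symmetric:
        return None, None  # plain cutlass_scaled_mm, no azp epilogue needed
    if input_zero_point is not None:  # static, per-tensor
        return input_zero_point * weight_col_sums, None
    return weight_col_sums, x_zp  # dynamic, per-token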
