
Commit 9a0c5de

[TPU] Add support for online w8a8 quantization (#22425)

Signed-off-by: Kyuyeun Kim <[email protected]>

1 parent 10a0253, commit 9a0c5de

File tree: 3 files changed, +82 −3 lines

.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh

Lines changed: 2 additions & 0 deletions

@@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
 run_and_track_test 6 "test_kv_cache_update_kernel.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
+run_and_track_test 7 "test_tpu_int8.py" \
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then

tests/v1/tpu/test_tpu_int8.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@ (new file; every line below is added)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.

Run `pytest tests/v1/tpu/test_tpu_int8.py`.
"""
import pytest

from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import (
    TPUInt8LinearMethod)
from vllm.platforms import current_platform

from ...models.registry import HF_EXAMPLE_MODELS

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            'quantization_config': {
                'quant_method': 'tpu_int8',
                'activation_scheme': 'dynamic'
            }
        }
    ])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
                        hf_overrides: dict, monkeypatch) -> None:
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_transformers_version(on_fail="skip")

    activation_scheme = hf_overrides.get('quantization_config',
                                         {}).get('activation_scheme')
    quantize_activation = activation_scheme == 'dynamic'

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        "or, being injured, not kill, except in",
        "without the heart, one can only see wrongly.",
        "but in rising every time we fall. - Nelson"
    ]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        def check_model(model):
            for name, module in model.named_modules():
                if not isinstance(module, LinearBase):
                    continue
                quant_method = module.quant_method
                assert isinstance(quant_method, TPUInt8LinearMethod)
                assert quant_method.quantize_activation == quantize_activation

        vllm.apply_model(check_model)
        outputs = vllm.generate_greedy(prompts, max_tokens)
        for (_, output), answer in zip(outputs, answers):
            assert answer in output
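For readers who want to exercise the same path outside the test harness, the sketch below shows how the same configuration might be passed through vLLM's offline LLM API, reusing the hf_overrides dictionary from the test above. The model name, dtype, and overrides are taken from the test; the exact argument plumbing may differ across vLLM versions.

# Hedged example (not part of the commit): enabling online w8a8 TPU int8
# quantization via the offline LLM API, mirroring the test's hf_overrides.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    dtype="bfloat16",
    hf_overrides={
        "quantization_config": {
            "quant_method": "tpu_int8",
            "activation_scheme": "dynamic",  # quantize activations at runtime
        }
    },
)
outputs = llm.generate(["A robot may not injure a human being"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)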

vllm/model_executor/layers/quantization/tpu_int8.py

Lines changed: 7 additions & 3 deletions
@@ -13,7 +13,7 @@
     QuantizationConfig)
 from vllm.model_executor.parameter import ModelWeightParameter
 
-ACTIVATION_SCHEMES = ["none"]
+ACTIVATION_SCHEMES = ["none", "dynamic"]
 
 
 class Int8TpuConfig(QuantizationConfig):
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Int8TpuConfig):
         self.quant_config = quant_config
+        self.quantize_activation = False
+        if self.quant_config.activation_scheme == 'dynamic':
+            self.quantize_activation = True
 
     def create_weights(self, layer: Module, input_size_per_partition: int,
                        output_partition_sizes: list[int], input_size: int,
@@ -107,15 +110,16 @@ def apply(self,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         try:
-            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
+            import torch_xla.experimental.custom_kernel  # noqa: F401
         except ImportError as err:
             raise ImportError(
                 "Please install torch_xla by following the instructions at "
                 "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                 "to run vLLM on TPU.") from err
         weight = layer.weight
         scale = layer.scale
-        out = torch.ops.xla.quantized_matmul(x, weight, scale)
+        out = torch.ops.xla.quantized_matmul_int8(
+            x, weight, scale, quantize_activation=self.quantize_activation)
         if bias is not None:
             out = out + bias
         return out
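For intuition about what quantize_activation=True asks the kernel to do, here is a rough plain-PyTorch reference of a dynamic w8a8 matmul. The real torch.ops.xla.quantized_matmul_int8 op is a fused TPU kernel and may differ in detail; this sketch assumes symmetric per-token activation scales and per-output-channel weight scales.

# Rough reference only -- not the torch_xla kernel. Assumes symmetric
# per-token activation scales and per-output-channel weight scales.
import torch


def w8a8_dynamic_matmul_reference(x: torch.Tensor, w_int8: torch.Tensor,
                                  w_scale: torch.Tensor) -> torch.Tensor:
    # Dynamically ("online") quantize activations per token to int8.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    # Integer matmul accumulated in int32, then rescaled back to float.
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    return acc.to(x.dtype) * x_scale * w_scale

With the "none" activation scheme (quantize_activation=False), only the weight side is int8 and the activations stay in their original dtype, which matches the pre-existing behavior of this layer.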
