
Commit 3a1d894

[TPU] support fp8 kv cache quantization (#19292)
Signed-off-by: Chengji Yao <[email protected]>
1 parent 2b504eb commit 3a1d894

File tree: 6 files changed (+95, -28 lines)

tests/entrypoints/llm/test_accuracy.py

Lines changed: 30 additions & 10 deletions
@@ -15,15 +15,18 @@
 from vllm.platforms import current_platform

 MODEL_NAMES = [
-    "Qwen/Qwen2-1.5B-Instruct",
+    "Qwen/Qwen3-1.7B",
     "google/gemma-3-1b-it",
 ]
+FP8_KV_MODEL_NAMES = [
+    "Qwen/Qwen3-1.7B",
+]
 NUM_CONCURRENT = 500
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUES = {
-    "Qwen/Qwen2-1.5B-Instruct": 0.58,
+    "Qwen/Qwen3-1.7B": 0.68,
     "google/gemma-3-1b-it": 0.25,
 }

@@ -70,10 +73,9 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
         if current_platform.is_tpu():
             # Limit compilation time for TPU V1

-            if model == "google/gemma-3-1b-it":
-                # TPU + google/gemma-3-1b-it + xet doesn't work well.
-                m.setenv("HF_HUB_DISABLE_XET", "1")
-
+            # xet doesn't work well for both Qwen/Qwen3-1.7B and
+            # google/gemma-3-1b-it
+            m.setenv("HF_HUB_DISABLE_XET", "1")
             more_args = "max_model_len=2048,max_num_seqs=64"

         # Add TP test (if provided)
@@ -83,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
         run_test(model, more_args)


-def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
-    """Run with the V0 Engine."""
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
+                    reason="V1 is currently only supported on CUDA and TPU")
+@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES)
+def test_lm_eval_accuracy_v1_engine_fp8_kv_cache(
+        model, monkeypatch: pytest.MonkeyPatch):
+    """Run with the V1 Engine."""

     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        run_test("Qwen/Qwen2-1.5B-Instruct")
+        m.setenv("VLLM_USE_V1", "1")
+
+        more_args = None
+        if current_platform.is_tpu():
+            # Limit compilation time for TPU V1
+
+            # xet doesn't work well for Qwen/Qwen3-1.7B
+            m.setenv("HF_HUB_DISABLE_XET", "1")
+            more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8"
+
+        # Add TP test (if provided)
+        if TPU_TP_TEST_STR:
+            more_args += ",{}".format(TPU_TP_TEST_STR)
+
+        run_test(model, more_args)
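For context, the configuration exercised by the new fp8 KV cache test maps roughly onto the public `LLM` API as follows. This is a minimal sketch, not part of the diff; it assumes a TPU host with vLLM's TPU backend installed, and mirrors the test's `more_args` string `max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8`.

```python
from vllm import LLM, SamplingParams

# Sketch only: same options the new test passes through more_args.
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    max_model_len=2048,
    max_num_seqs=128,
    kv_cache_dtype="fp8",  # store K/V as float8_e4m3fn on TPU
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)
```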

tests/v1/tpu/test_pallas.py

Lines changed: 2 additions & 0 deletions
@@ -95,4 +95,6 @@ class FakeAttentionLayer:
         sm_scale=scale,
         sliding_window=sliding_window,
         soft_cap=logits_soft_cap,
+        k_scale=1.0,
+        v_scale=1.0,
    )

vllm/engine/arg_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1358,10 +1358,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if current_platform.is_rocm() or (
-                current_platform.is_cuda()
-                and current_platform.is_device_capability(100)
-            ):  # handle hpu also for OOT platform
+            if (current_platform.is_rocm()
+                    or (current_platform.is_cuda()
+                        and current_platform.is_device_capability(100))
+                    or current_platform.is_tpu()):
                 supported = True
             elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (

vllm/platforms/tpu.py

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
     simple_compile_backend: str = "openxla"

-    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
+    supported_quantization: list[str] = [
+        "fp8", "tpu_int8", "compressed-tensors"
+    ]

     additional_env_vars: list[str] = [
         "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"

vllm/v1/attention/backends/pallas.py

Lines changed: 50 additions & 8 deletions
@@ -24,6 +24,19 @@
 # TPU requires the head size to be a multiple of 128.
 TPU_HEAD_SIZE_ALIGNMENT = 128

+# Note: TPU can use fp8 as the storage dtype but doesn't support converting
+# from uint8 to fp32 directly, hence a dtype mapping different from GPU.
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+

 class PallasAttentionBackend(AttentionBackend):

@@ -152,15 +165,18 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")

+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -194,7 +210,6 @@ def forward(
             output = torch.ones_like(query)
             return output

-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -215,10 +230,21 @@ def forward(
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(
-                key, value, kv_cache, slot_mapping,
+                key,
+                value,
+                kv_cache,
+                slot_mapping,
                 attn_metadata.num_slices_per_kv_cache_update_block,
-                attn_metadata.num_kv_update_slices)
-
+                attn_metadata.num_kv_update_slices,
+                self.kv_cache_quantized_dtype,
+                layer._k_scale_float,
+                layer._v_scale_float,
+            )
+
+        if self.kv_cache_quantized_dtype is not None and (
+                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
+            raise ValueError(
+                "k_scale_float and v_scale_float must be non-zero")
         output = torch.ops.xla.ragged_paged_attention(
             query,
             kv_cache,
@@ -236,6 +262,8 @@ def forward(
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
         )

         if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@@ -251,18 +279,32 @@ def write_to_kv_cache(
     slot_mapping: torch.Tensor,
     num_slices_per_kv_cache_update_block: int,
     num_kv_update_slices: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """ Write the key and values to the KV cache.

     Args:
-        key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
         num_slices_per_kv_cache_update_block: int
     """
     _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    if kv_cache_quantized_dtype is not None:
+        dtype_info = torch.finfo(kv_cache_quantized_dtype)
+        key = key.to(torch.float32) / k_scale
+        # NOTE: clamp avoids values out of range of the quantized dtype
+        key = torch.clamp(key, dtype_info.min, dtype_info.max)
+        key = key.to(kv_cache_quantized_dtype)
+        value = value.to(torch.float32) / v_scale
+        value = torch.clamp(value, dtype_info.min, dtype_info.max)
+        value = value.to(kv_cache_quantized_dtype)
+
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
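The quantization step added to `write_to_kv_cache` can be reproduced in isolation with plain PyTorch. The sketch below is a standalone example under simplifying assumptions (run on CPU with stock torch rather than under torch_xla on TPU, with a made-up scale value); it shows the divide-by-scale, clamp, and cast performed at cache-write time, plus the multiply-by-scale that the attention kernel conceptually applies when reading the cache back.

```python
import torch

def quantize_kv(x: torch.Tensor, scale: float,
                dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    info = torch.finfo(dtype)
    x = x.to(torch.float32) / scale          # scale values into the fp8 range
    x = torch.clamp(x, info.min, info.max)   # avoid out-of-range values
    return x.to(dtype)

def dequantize_kv(x: torch.Tensor, scale: float) -> torch.Tensor:
    # k_scale / v_scale undo the earlier division
    return x.to(torch.float32) * scale

key = torch.randn(4, 8, 128)   # [num_tokens, num_kv_heads, head_size]
k_scale = 0.05                 # illustrative scale, not a tuned value
restored = dequantize_kv(quantize_kv(key, k_scale), k_scale)
print((restored - key).abs().max())  # small quantization error
```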

vllm/v1/worker/tpu_model_runner.py

Lines changed: 6 additions & 5 deletions
@@ -32,9 +32,10 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available, prev_power_of_2)
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
+                        prev_power_of_2)
+from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
+                                               PallasAttentionBackend,
                                                PallasMetadata,
                                                get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -142,11 +143,11 @@ def __init__(
         if cache_config.cache_dtype == "auto":
             model_dtype = self.dtype
             if isinstance(model_dtype, str):
-                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
             else:
                 self.kv_cache_dtype = model_dtype
         else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype
