Commit e42af78

[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention (#24197)
Signed-off-by: Xiaozhu <[email protected]>
1 parent: 074854b

File tree: 3 files changed, +121 -14 lines

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -163,6 +163,7 @@
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
+    VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
@@ -1155,6 +1156,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_TRTLLM_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
 
+    # If set to 1, do not quantize Q to fp8 even when an fp8 KV cache is used
+    "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
+    lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
+
     # If set, it means we pre-downloaded cubin files and flashinfer will
     # read the cubin files directly.
     "VLLM_HAS_FLASHINFER_CUBIN":
@@ -1310,6 +1315,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
+        "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
         "VLLM_ROCM_USE_AITER",
         "VLLM_ROCM_USE_AITER_PAGED_ATTN",
         "VLLM_ROCM_USE_AITER_LINEAR",

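The new flag is parsed with bool(int(...)), so only integer strings are accepted. A minimal sketch of the parsing semantics (parse_flag below is an illustration of the lambda added above, not a vLLM API):

import os

def parse_flag() -> bool:
    # Mirrors the lambda registered above: unset defaults to "0" (disabled),
    # and any non-zero integer string enables the flag.
    return bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0")))

os.environ.pop("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", None)
print(parse_flag())  # False

os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"
print(parse_flag())  # True

# Non-integer values such as "true" would raise ValueError in int(),
# so stick to "0"/"1".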
vllm/utils/flashinfer.py

Lines changed: 6 additions & 5 deletions
@@ -200,11 +200,6 @@ def use_trtllm_attention(
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
-    # TRTLLM prefill attention does not support FP8 kv cache with
-    # non-quantized query
-    if is_prefill and kv_cache_dtype.startswith("fp8"):
-        return False
-
     # If sinks are being used, we must use TRTLLM attention as it's
     # the only backend that supports them
     if has_sinks:
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
     return output
 
 
+@functools.cache
+def flashinfer_disable_q_quantization() -> bool:
+    """Cache result which only depends on the environment"""
+    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
+
+
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",

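Because flashinfer_disable_q_quantization() is wrapped in functools.cache, the environment variable is effectively read once per process. A hedged illustration (assumes a working vLLM installation):

import os

# Set the flag before the first call; later changes are ignored because
# functools.cache memoizes the first result.
os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"

from vllm.utils.flashinfer import flashinfer_disable_q_quantization

print(flashinfer_disable_q_quantization())  # True

os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "0"
print(flashinfer_disable_q_quantization())  # still True (cached)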
vllm/v1/attention/backends/flashinfer.py

Lines changed: 109 additions & 9 deletions
@@ -25,7 +25,8 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (supports_trtllm_attention,
+from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+                                   supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
@@ -48,8 +49,89 @@
 logger = init_logger(__name__)
 
 
-class FlashInferBackend(AttentionBackend):
+@triton.jit
+def _trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache_ptr,
+    block_tables_prefill_ptr,
+    block_table_stride,
+    mock_kv_cache_ptr,
+    k_scale_ptr,
+    v_scale_ptr,
+    K_CACHE_STRIDE: tl.constexpr,
+    KV_CACHE_STRIDE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0).to(tl.int64)
+    mock_block_table_idx = tl.program_id(1).to(tl.int64)
+    orig_page_num = tl.load(block_tables_prefill_ptr +
+                            batch_idx * block_table_stride +
+                            mock_block_table_idx).to(tl.int64)
+    if orig_page_num <= 0:
+        return
+    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
+
+    # Dequantize K
+    k_scale_val = tl.load(k_scale_ptr)
+    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
+    mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+                         + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+    # Dequantize V
+    v_scale_val = tl.load(v_scale_ptr)
+    offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
+              tl.arange(0, K_CACHE_STRIDE))
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
+    mock_cache_offset = (
+        (batch_idx * block_table_stride + mock_block_table_idx + 1) *
+        KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+
+def trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache: torch.Tensor,
+    block_tables_prefill: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dequant_dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    batch_size, num_of_page_per_token = block_tables_prefill.shape
+    s = kv_cache.shape
+    assert s[1] == 2
+    assert dequant_dtype in (torch.bfloat16, torch.float16)
+    k_cache_stride = s[2] * s[3] * s[4]
+    kv_cache_stride = k_cache_stride * s[1]
+    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
+    # mock kv cache contains just the pages needed by this prefill
+    mock_kv_cache = torch.empty(new_s,
+                                dtype=dequant_dtype,
+                                device=kv_cache.device)
+    # we simply sequentially index the pages needed by this prefill
+    mock_block_table = torch.arange(
+        start=1,
+        end=batch_size * num_of_page_per_token + 1,
+        dtype=torch.int32,
+        device=block_tables_prefill.device,
+    ).reshape(batch_size, num_of_page_per_token)
+    grid = (batch_size, num_of_page_per_token)
+    _trtllm_prefill_attn_kvfp8_dequant[grid](
+        kv_cache,
+        block_tables_prefill,
+        num_of_page_per_token,
+        mock_kv_cache,
+        k_scale,
+        v_scale,
+        k_cache_stride,
+        kv_cache_stride,
+    )
+    return mock_kv_cache, mock_block_table
+
 
+class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
     @classmethod
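For readers following the indexing above: the kernel gathers only the pages referenced by the prefill's block table into a fresh cache, applying the K/V scales, and the new block table simply enumerates the copied slots starting at 1 (slot 0 is never written, matching the orig_page_num <= 0 early return). A CPU-only sketch of the same gather-and-dequantize in plain PyTorch (reference_kvfp8_dequant is a hypothetical helper for illustration, not part of the commit):

import torch

def reference_kvfp8_dequant(kv_cache, block_tables_prefill, k_scale, v_scale,
                            dequant_dtype=torch.bfloat16):
    # kv_cache: (num_pages, 2, ...) where index 0 holds the K page and 1 the V page.
    # Float tensors stand in for fp8 here so the sketch runs on CPU.
    batch, pages_per_req = block_tables_prefill.shape
    num_pages, two, *rest = kv_cache.shape
    assert two == 2
    mock_kv_cache = torch.zeros(batch * pages_per_req + 1, 2, *rest,
                                dtype=dequant_dtype)
    # Slots 1..batch*pages_per_req are filled sequentially; slot 0 stays empty.
    mock_block_table = torch.arange(
        1, batch * pages_per_req + 1,
        dtype=torch.int32).reshape(batch, pages_per_req)
    for b in range(batch):
        for slot in range(pages_per_req):
            page = int(block_tables_prefill[b, slot])
            if page <= 0:
                continue  # padding entry, skipped like in the kernel
            dst = b * pages_per_req + slot + 1
            mock_kv_cache[dst, 0] = (kv_cache[page, 0].float() * k_scale).to(dequant_dtype)
            mock_kv_cache[dst, 1] = (kv_cache[page, 1].float() * v_scale).to(dequant_dtype)
    return mock_kv_cache, mock_block_table

# Tiny example: 2 requests, 2 pages each, 1 head, page size 4, head dim 8.
kv = torch.randn(10, 2, 1, 4, 8)
bt = torch.tensor([[3, 5], [7, 0]], dtype=torch.int32)  # 0 marks padding
mock_cache, mock_bt = reference_kvfp8_dequant(kv, bt, torch.tensor(0.5),
                                              torch.tensor(0.25))
print(mock_cache.shape, mock_bt)  # torch.Size([5, 2, 1, 4, 8]); [[1, 2], [3, 4]]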
@@ -122,7 +204,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
 
 @dataclass
 class FlashInferMetadata:
-
     num_actual_tokens: int  # Number of tokens excluding padding.
 
     # The data type of the query
@@ -175,8 +256,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                                          self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL
+        self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -201,7 +282,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0]:
+        if supports_trtllm_attention()[0] and \
+                not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
@@ -795,11 +877,29 @@ def forward(
             assert self.o_sf_scale is None
             out = output[num_decode_tokens:]
 
+            if attn_metadata.q_data_type != FP8_DTYPE \
+                    and self.kv_cache_dtype.startswith("fp8"):
+                # TRTLLM prefill attention does not support a non-fp8 query
+                # with an fp8 kv cache. To enable prefill attention with an
+                # fp8 kv cache, we construct a mock block table and a mock
+                # kv cache holding the prefill's pages dequantized to the
+                # query dtype.
+                mock_kv_cache, mock_block_table = (
+                    trtllm_prefill_attn_kvfp8_dequant(
+                        kv_cache_permute,
+                        block_tables_prefill,
+                        layer._k_scale,
+                        layer._v_scale,
+                        attn_metadata.q_data_type,
+                    ))
+            else:
+                mock_kv_cache = kv_cache_permute
+                mock_block_table = block_tables_prefill
+
             trtllm_batch_context_with_kv_cache(
                 query=prefill_query,
-                kv_cache=kv_cache_permute,
+                kv_cache=mock_kv_cache,
                 workspace_buffer=workspace_buffer,
-                block_tables=block_tables_prefill,
+                block_tables=mock_block_table,
                 seq_lens=seq_lens_prefill,
                 max_q_len=attn_metadata.max_q_len,
                 max_kv_len=attn_metadata.max_seq_len,
@@ -837,7 +937,7 @@ def forward(
             decode_query = decode_query.contiguous()
             workspace_buffer = decode_wrapper._float_workspace_buffer
             block_tables_decode = attn_metadata.\
-                block_table_tensor[:num_decode_tokens]
+                    block_table_tensor[:num_decode_tokens]
             seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
 
             # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND

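Taken together, the change lets the TRTLLM prefill path run against an fp8 KV cache while the query stays in the model dtype. A hedged end-to-end sketch (the model name is only an example, and forcing backend selection via VLLM_ATTENTION_BACKEND and VLLM_USE_TRTLLM_ATTENTION assumes hardware where TRTLLM attention is actually supported):

import os

# Keep Q unquantized even though the KV cache is fp8; prefill then goes
# through the dequantization path added in this commit.
os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # example model
          kv_cache_dtype="fp8")
out = llm.generate(["The capital of France is"])
print(out[0].outputs[0].text)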