 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (supports_trtllm_attention,
+from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+                                   supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
@@ -48,8 +49,89 @@
 logger = init_logger(__name__)


-class FlashInferBackend(AttentionBackend):
+@triton.jit
+def _trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache_ptr,
+    block_tables_prefill_ptr,
+    block_table_stride,
+    mock_kv_cache_ptr,
+    k_scale_ptr,
+    v_scale_ptr,
+    K_CACHE_STRIDE: tl.constexpr,
+    KV_CACHE_STRIDE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0).to(tl.int64)
+    mock_block_table_idx = tl.program_id(1).to(tl.int64)
+    orig_page_num = tl.load(block_tables_prefill_ptr +
+                            batch_idx * block_table_stride +
+                            mock_block_table_idx).to(tl.int64)
+    if orig_page_num <= 0:
+        return
+    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
+
+    # Dequantize K
+    k_scale_val = tl.load(k_scale_ptr)
+    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
+    mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+                         + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+    # Dequantize V
+    v_scale_val = tl.load(v_scale_ptr)
+    offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
+              tl.arange(0, K_CACHE_STRIDE))
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
+    mock_cache_offset = (
+        (batch_idx * block_table_stride + mock_block_table_idx + 1) *
+        KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+
+def trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache: torch.Tensor,
+    block_tables_prefill: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dequant_dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    batch_size, num_of_page_per_token = block_tables_prefill.shape
+    s = kv_cache.shape
+    assert s[1] == 2
+    assert dequant_dtype in (torch.bfloat16, torch.float16)
+    k_cache_stride = s[2] * s[3] * s[4]
+    kv_cache_stride = k_cache_stride * s[1]
+    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
+    # mock kv cache contains just the pages needed by this prefill
+    mock_kv_cache = torch.empty(new_s,
+                                dtype=dequant_dtype,
+                                device=kv_cache.device)
+    # we simply sequentially index the pages needed by this prefill
+    mock_block_table = torch.arange(
+        start=1,
+        end=batch_size * num_of_page_per_token + 1,
+        dtype=torch.int32,
+        device=block_tables_prefill.device,
+    ).reshape(batch_size, num_of_page_per_token)
+    grid = (batch_size, num_of_page_per_token)
+    _trtllm_prefill_attn_kvfp8_dequant[grid](
+        kv_cache,
+        block_tables_prefill,
+        num_of_page_per_token,
+        mock_kv_cache,
+        k_scale,
+        v_scale,
+        k_cache_stride,
+        kv_cache_stride,
+    )
+    return mock_kv_cache, mock_block_table
+

+class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True

     @classmethod
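
For reference, the page remapping performed by the Triton kernel above can be expressed in plain PyTorch. The sketch below is illustrative only (not part of this change); the function name is hypothetical, and it assumes the (num_pages, 2, page_size, num_kv_heads, head_dim) cache layout asserted in trtllm_prefill_attn_kvfp8_dequant.

import torch

def reference_kvfp8_dequant(kv_cache, block_tables_prefill, k_scale, v_scale,
                            dequant_dtype):
    # Plain-PyTorch mirror of _trtllm_prefill_attn_kvfp8_dequant: copy each
    # page referenced by the prefill block table into a freshly allocated
    # dequantized cache. Page 0 of the mock cache is left unused and
    # non-positive page ids are skipped, matching the kernel's early return.
    batch_size, pages_per_req = block_tables_prefill.shape
    mock_kv_cache = torch.zeros(
        (batch_size * pages_per_req + 1, *kv_cache.shape[1:]),
        dtype=dequant_dtype, device=kv_cache.device)
    mock_block_table = torch.arange(
        1, batch_size * pages_per_req + 1, dtype=torch.int32,
        device=block_tables_prefill.device).reshape(batch_size, pages_per_req)
    for b in range(batch_size):
        for p in range(pages_per_req):
            src = int(block_tables_prefill[b, p])
            if src <= 0:
                continue
            dst = b * pages_per_req + p + 1
            mock_kv_cache[dst, 0] = (kv_cache[src, 0].to(torch.float32) *
                                     k_scale).to(dequant_dtype)
            mock_kv_cache[dst, 1] = (kv_cache[src, 1].to(torch.float32) *
                                     v_scale).to(dequant_dtype)
    return mock_kv_cache, mock_block_table

Comparing this reference against the kernel's output with torch.testing.assert_close is one way to unit-test the new code path.
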
@@ -122,7 +204,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:

 @dataclass
 class FlashInferMetadata:
-
     num_actual_tokens: int  # Number of tokens excluding padding.

     # The data type of the query
@@ -175,8 +256,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                                          self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL
+        self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -201,7 +282,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype

-        if supports_trtllm_attention()[0]:
+        if supports_trtllm_attention()[0] and \
+                not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
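
With the added import, fp8 query quantization becomes opt-out rather than unconditional whenever TRTLLM attention is usable. A compact sketch of the resulting selection is shown below; it assumes flashinfer_disable_q_quantization() reflects a user-set opt-out flag, as its name and the branch above suggest, and is not the exact code path.

import torch
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
                                   supports_trtllm_attention)

def select_q_dtype(kv_cache_dtype: torch.dtype,
                   model_dtype: torch.dtype) -> torch.dtype:
    # TRTLLM usable and quantization not disabled -> query uses the KV-cache
    # dtype (e.g. fp8); otherwise the query stays in the model dtype
    # (bf16/fp16).
    if supports_trtllm_attention()[0] and \
            not flashinfer_disable_q_quantization():
        return kv_cache_dtype
    return model_dtype
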
@@ -795,11 +877,29 @@ def forward(
                     assert self.o_sf_scale is None
                     out = output[num_decode_tokens:]

+                if attn_metadata.q_data_type != FP8_DTYPE \
+                        and self.kv_cache_dtype.startswith("fp8"):
+                    # TRTLLM prefill attention does not support a BF16 query
+                    # with an fp8 KV cache, so to run prefill with an fp8 KV
+                    # cache we build a mock block table and a mock KV cache
+                    # holding the dequantized (BF16) pages for this prefill.
+                    mock_kv_cache, mock_block_table = (
+                        trtllm_prefill_attn_kvfp8_dequant(
+                            kv_cache_permute,
+                            block_tables_prefill,
+                            layer._k_scale,
+                            layer._v_scale,
+                            attn_metadata.q_data_type,
+                        ))
+                else:
+                    mock_kv_cache = kv_cache_permute
+                    mock_block_table = block_tables_prefill
+
                 trtllm_batch_context_with_kv_cache(
                     query=prefill_query,
-                    kv_cache=kv_cache_permute,
+                    kv_cache=mock_kv_cache,
                     workspace_buffer=workspace_buffer,
-                    block_tables=block_tables_prefill,
+                    block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
                     max_q_len=attn_metadata.max_q_len,
                     max_kv_len=attn_metadata.max_seq_len,
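
One practical consequence of this fallback: the mock KV cache is a fresh allocation sized to the full prefill block table (plus one reserved page at index 0, which the mock block table never references) and lives only for the duration of the prefill call. The helper below is a hypothetical illustration, derived from the shapes used in trtllm_prefill_attn_kvfp8_dequant, of that transient footprint.

import torch

def mock_kv_cache_bytes(batch_size: int, pages_per_req: int, page_size: int,
                        num_kv_heads: int, head_dim: int,
                        dequant_dtype: torch.dtype = torch.bfloat16) -> int:
    # Layout: (num_pages, 2, page_size, num_kv_heads, head_dim); index 0 is
    # left unused because the mock block table starts at page id 1.
    num_pages = batch_size * pages_per_req + 1
    elems = num_pages * 2 * page_size * num_kv_heads * head_dim
    return elems * torch.finfo(dequant_dtype).bits // 8

For example, 8 prefill requests of 8k tokens with 16-token pages, 8 KV heads and head_dim 128 come to roughly 4097 pages * 2 * 16 * 8 * 128 elements * 2 bytes, about 268 MB of temporary BF16 cache.
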
@@ -837,7 +937,7 @@ def forward(
             decode_query = decode_query.contiguous()
             workspace_buffer = decode_wrapper._float_workspace_buffer
             block_tables_decode = attn_metadata.\
-                block_table_tensor[:num_decode_tokens]
+                    block_table_tensor[:num_decode_tokens]
             seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

             # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND