|
from vllm.attention.backends.rocm_flash_attn import (
    ROCmFlashAttentionMetadata as FlashAttentionMetadata)

+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
+
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
|
@@ -79,6 +90,11 @@ def __init__(
            return_hidden_states=return_hidden_states,
        )

+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
+
    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                    num_queries):
        assert isinstance(attn_metadata, FlashAttentionMetadata)
|
@@ -286,6 +302,37 @@ def execute_model(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = \
+                    self.graph_runners[model_input.
+                                       virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
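
Note on the hunk above: the FlashInfer workspace buffers are allocated lazily, on the first step that actually runs the flashinfer backend, and the decode/prefill wrappers are built once on top of them and reused on every later step. A minimal standalone sketch of that initialization follows; the helper name init_flashinfer_wrappers is illustrative only, it assumes flashinfer and torch are importable, and it uses the same constructor calls as the diff.

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

# 256 MiB of scratch space, the same size the diff uses.
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024


def init_flashinfer_wrappers(device):
    # One uint8 workspace buffer per wrapper; FlashInfer uses it as scratch
    # space across calls, so it is allocated once and then reused.
    decode_workspace = torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                                   dtype=torch.uint8,
                                   device=device)
    # "NHD" selects the (seq_len, num_heads, head_dim) KV-cache layout.
    decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        decode_workspace, "NHD")

    prefill_workspace = torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                                    dtype=torch.uint8,
                                    device=device)
    prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
        prefill_workspace, "NHD")
    return decode_wrapper, prefill_wrapper

# Example: decode, prefill = init_flashinfer_wrappers(torch.device("cuda"))

When use_cuda_graph is set, the diff instead takes the decode wrapper from the captured graph_runners[virtual_engine][batch_size], since a CUDA graph replays the exact wrapper it was captured with; only the eager path uses the wrapper built lazily here.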
|
|