
Commit e963045

[SpecDecode] Support FlashInfer in DraftModelRunner (#6926)
1 parent 82a1b1a commit e963045

File tree

1 file changed: +47 −0 lines changed


vllm/spec_decode/draft_model_runner.py

Lines changed: 47 additions & 0 deletions
@@ -11,6 +11,17 @@
 from vllm.attention.backends.rocm_flash_attn import (
     ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
+
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
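
As an aside (not part of this commit), the optional-import guard above can be exercised on its own. A minimal sketch follows; the helper names flashinfer_available and require_flashinfer are illustrative and do not exist in vLLM.

# Minimal sketch of the optional-dependency pattern used in this hunk.
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024  # 256 MiB, as in the diff
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0


def flashinfer_available() -> bool:
    # True only when the flashinfer package could be imported.
    return BatchDecodeWithPagedKVCacheWrapper is not None


def require_flashinfer() -> None:
    # Fail fast with a clear message when the FlashInfer backend is requested
    # but the package is missing.
    if not flashinfer_available():
        raise ImportError("FlashInfer is not installed; install the flashinfer "
                          "package matching your CUDA and torch build.")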
@@ -79,6 +90,11 @@ def __init__(
             return_hidden_states=return_hidden_states,
         )
 
+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
+
     def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                     num_queries):
         assert isinstance(attn_metadata, FlashAttentionMetadata)
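
For context, a self-contained sketch of the lazy-initialization pattern these attributes set up (illustrative class, not vLLM code): the workspace buffers stay unallocated until the first forward pass that actually needs FlashInfer, so runs on other attention backends pay no memory cost.

import torch


class LazyFlashInferBuffers:
    """Illustrative sketch: allocate FlashInfer workspace buffers on first use."""

    def __init__(self) -> None:
        # Mirrors the attributes added to DraftModelRunner.__init__ above.
        self.decode_workspace_buffer = None
        self.prefill_workspace_buffer = None

    def ensure_allocated(self, size: int, device: str = "cuda") -> None:
        # Allocate once; later calls are no-ops, so repeated execute_model
        # passes do not re-pay the allocation.
        if self.decode_workspace_buffer is None:
            self.decode_workspace_buffer = torch.empty(
                size, dtype=torch.uint8, device=device)
            self.prefill_workspace_buffer = torch.empty(
                size, dtype=torch.uint8, device=device)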
@@ -286,6 +302,37 @@ def execute_model(
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = \
+                    self.graph_runners[model_input.
+                        virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
         # Detect exec mode
         assert model_input.attn_metadata is not None
         use_cuda_graph = False
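
Putting the change in context, a hedged usage sketch of what this commit enables: speculative decoding with a draft model while the FlashInfer attention backend is selected. Setting VLLM_ATTENTION_BACKEND is the documented way to pick the backend; the model names and engine arguments below are examples for the vLLM release this commit targets and may differ in later versions.

import os

# Select the FlashInfer attention backend before vLLM is imported.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# Example target/draft model pair; any compatible pair works.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,  # required for speculative decoding at the time
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)

With this commit, the DraftModelRunner's own forward passes go through the same FlashInfer prefill/decode wrappers as the target model, instead of being limited to the FlashAttention backend.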
