
Commit 5d81cbb

tdoublep and yyihuang authored and committed
[V1] [Hybrid] Enable Full CUDA Graph (decode-only) for Mamba layers (vllm-project#21401)
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Avery Yingyi Huang <[email protected]>
1 parent 62e174d commit 5d81cbb
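
For context, here is a minimal, hedged sketch of how the configuration exercised by the new test below might look from user code. It is not part of this commit; it assumes that vllm.LLM forwards max_num_seqs, compilation_config, and enable_prefix_caching to the engine the same way the test's vllm_runner fixture does, and that the same V1/FlashInfer environment variables apply outside the test harness (max_num_seqs=64 is an illustrative value):

# Hedged sketch, not part of this commit: enabling decode-only full CUDA graphs
# for a hybrid Mamba model, mirroring the settings used by the new test below.
import os

# The test pins V1 and FlashInfer for hybrid models (reorder_batch behaviour).
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",
    max_num_seqs=64,                               # illustrative value
    compilation_config={"full_cuda_graph": True},  # capture full CUDA graphs
    enable_prefix_caching=False,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)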

File tree: 2 files changed, +103 -1 lines changed


tests/models/language/generation/test_hybrid.py

Lines changed: 60 additions & 0 deletions
@@ -384,3 +384,63 @@ def test_distributed_correctness(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         compilation_config={'full_cuda_graph': True},
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
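
For reference, the new test can be selected on its own with pytest, e.g. pytest tests/models/language/generation/test_hybrid.py -k test_full_cuda_graph (it skips itself if the model or required transformers version is unavailable). It checks the V1 full-CUDA-graph outputs against the HF reference when available, falling back to the V0 outputs otherwise.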

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 43 additions & 1 deletion
@@ -7,8 +7,10 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     reorder_batch_threshold: ClassVar[int] = 1
 
@@ -90,8 +94,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
 
     def build(self,
               common_prefix_len: int,
@@ -144,6 +158,14 @@ def build(self,
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))
 
+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad state tensor for CUDA graph
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
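
To make the padding step in the hunk above concrete, here is a standalone toy illustration (not vLLM code; the batch size, indices, and the sentinel value -1 are made up for the example, and vLLM takes its sentinel from vllm.attention.backends.utils.PAD_SLOT_ID): the live decode state indices are copied into a preallocated buffer sized for the captured graph, and the padded tail is marked so the kernels touch no real state.

# Toy, standalone illustration of the decode-only CUDA-graph padding pattern.
import torch

PAD_SLOT_ID = -1                      # sentinel for "no state slot" (illustrative)

capture_bs = 8                        # batch size the graph was captured with
buffer = torch.empty((capture_bs, ), dtype=torch.int32)

live = torch.tensor([3, 7, 1], dtype=torch.int32)   # 3 real decode requests
num_decodes = live.numel()

buffer[:num_decodes].copy_(live)      # real slots first
padded = buffer[:capture_bs]
padded[num_decodes:] = PAD_SLOT_ID    # padded entries reference no state

print(padded)  # tensor([ 3,  7,  1, -1, -1, -1, -1, -1], dtype=torch.int32)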
@@ -160,3 +182,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
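
The two new methods are the capture/replay hooks: build_for_cudagraph_capture builds decode-shaped metadata for graph capture, and can_run_in_cudagraph tells the runner whether a batch may be replayed through the captured graph. A simplified, hedged sketch of the dispatch idea only (the helper names below are illustrative and are not the actual vLLM runner code):

# Illustrative only: gate full-graph replay to pure decode batches
# (every request in the batch has query length 1).
def dispatch_forward(builder, common_attn_metadata, replay_captured_graph,
                     run_eager):
    if builder.can_run_in_cudagraph(common_attn_metadata):
        # Decode-only batch: replay the captured full CUDA graph.
        return replay_captured_graph(common_attn_metadata)
    # Prefill or mixed batch: fall back to the eager / piecewise path.
    return run_eager(common_attn_metadata)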
