 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
from vllm .v1 .kv_cache_interface import AttentionSpec , MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     reorder_batch_threshold: ClassVar[int] = 1
 
@@ -90,8 +94,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
 
     def build(self,
               common_prefix_len: int,
@@ -144,6 +158,14 @@ def build(self,
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))
 
+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad state tensor for CUDA graph
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
@@ -160,3 +182,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
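
For context on the padding path added to build(): the sketch below is not part of this diff, it only shows the same idea in isolation. CAPTURE_SIZES and pad_for_cudagraph are simplified stand-ins for vllm_config.pad_for_cudagraph(), PAD_SLOT_ID is assumed to be the -1 sentinel from vllm.attention.backends.utils, and state_indices_buf plays the role of self.state_indices_tensor.

import torch

PAD_SLOT_ID = -1                 # assumed sentinel for "no real request in this slot"
CAPTURE_SIZES = [1, 2, 4, 8]     # assumed CUDA graph capture batch sizes
DECODE_CUDAGRAPH_MAX_BS = 8      # persistent buffer size, like decode_cudagraph_max_bs

# Allocated once, like self.state_indices_tensor in __init__ above.
state_indices_buf = torch.empty(DECODE_CUDAGRAPH_MAX_BS, dtype=torch.int32)


def pad_for_cudagraph(n: int) -> int:
    # Stand-in for vllm_config.pad_for_cudagraph(): round up to a captured size.
    return next(s for s in CAPTURE_SIZES if s >= n)


def pad_state_indices(state_indices: torch.Tensor) -> torch.Tensor:
    num_decodes = state_indices.numel()
    num_input_tokens = pad_for_cudagraph(num_decodes)
    # Copy the real indices into the persistent buffer, then mark the padded
    # tail so those slots can be skipped when the graph is replayed.
    state_indices_buf[:num_decodes].copy_(state_indices)
    padded = state_indices_buf[:num_input_tokens]
    padded[num_decodes:] = PAD_SLOT_ID
    return padded


# Three decode requests are padded up to the next capture size (4):
print(pad_state_indices(torch.tensor([10, 11, 12], dtype=torch.int32)))
# -> tensor([10, 11, 12, -1], dtype=torch.int32)

Copying into a pre-allocated, persistent buffer keeps the tensor address stable across replays, which is what allows the captured decode graph to be reused for any batch size up to the capture size.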