Commit fceafaf

[Bugfix][mamba] Fix type annotation of Mamba2Metadata (#22787)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 6b794c7 commit fceafaf

2 files changed: +26 -21 lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 4 additions & 4 deletions
@@ -473,12 +473,12 @@ def forward_cuda(
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            has_initial_states_p = attn_metadata.has_initial_states
+            has_initial_states_p = attn_metadata.has_initial_states_p
             prep_initial_states = attn_metadata.prep_initial_states
             chunk_size = attn_metadata.chunk_size
-            seq_idx_p = attn_metadata.seq_idx
-            chunk_indices_p = attn_metadata.chunk_indices
-            chunk_offsets_p = attn_metadata.chunk_offsets
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
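For readers skimming the rename: the _p suffix marks tensors that describe only the prefill portion of the batch, and after this commit they are typed Optional. A minimal sketch of the guard a consumer of these attributes needs (the helper function is hypothetical; only the attribute names come from this commit):

def read_prefill_metadata(attn_metadata):
    # Prefill-only fields carry the _p suffix, are typed Optional, and are
    # None when the batch contains no prefill request, so check before use.
    seq_idx_p = attn_metadata.seq_idx_p
    if seq_idx_p is None:
        return None  # decode-only batch: skip the prefill path
    return (attn_metadata.has_initial_states_p, seq_idx_p,
            attn_metadata.chunk_indices_p, attn_metadata.chunk_offsets_p)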

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 22 additions & 17 deletions
@@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
     query_start_loc: torch.Tensor
     seq_lens: torch.Tensor
 
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    # The following tensors only contain prefill requests and will be None if
+    # the batch has no prefill request.
+    has_initial_states_p: Optional[torch.Tensor]
+    seq_idx_p: Optional[torch.Tensor]
+    chunk_indices_p: Optional[torch.Tensor]
+    chunk_offsets_p: Optional[torch.Tensor]
 
     state_indices_tensor: torch.Tensor  # shape: [batch,]
+
+    # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.tensor] = None
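Read as plain Python rather than as a diff, the prefill-related part of the metadata class now looks roughly like this (a simplified excerpt mirroring the hunk above; fields outside the hunk are elided):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2AttentionMetadata:
    # ... batch-wide fields such as query_start_loc and seq_lens elided ...
    prep_initial_states: bool
    chunk_size: int

    # The following tensors only contain prefill requests and will be None if
    # the batch has no prefill request.
    has_initial_states_p: Optional[torch.Tensor]
    seq_idx_p: Optional[torch.Tensor]
    chunk_indices_p: Optional[torch.Tensor]
    chunk_offsets_p: Optional[torch.Tensor]

    state_indices_tensor: torch.Tensor  # shape: [batch,]
    # ... causal_conv1d helper fields with defaults elided ...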
@@ -115,11 +120,11 @@ def build(self,
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
 
-        seq_idx = None
-        chunk_indices, chunk_offsets = None, None
+        seq_idx_p = None
+        chunk_indices_p, chunk_offsets_p = None, None
         # Need flags to indicate if there are initial states
         # currently we really only support the FlashAttention backend
-        has_initial_states = None
+        has_initial_states_p = None
         prep_initial_states = False
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
@@ -135,25 +140,25 @@ def build(self,
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
                 -num_prefills - 1:] - num_decode_tokens
 
-            seq_idx = torch.repeat_interleave(torch.arange(
+            seq_idx_p = torch.repeat_interleave(torch.arange(
                 num_prefills,
                 dtype=torch.int32,
                 device=query_start_loc_p.device),
-                                              query_start_loc_p.diff(),
-                                              output_size=num_prefill_tokens)
-            seq_idx.unsqueeze_(0)
+                                                query_start_loc_p.diff(),
+                                                output_size=num_prefill_tokens)
+            seq_idx_p.unsqueeze_(0)
 
             # We compute metadata for chunked prefill once at the top level
             # model forward and reuse them in mamba layers. If not needed,
             # they will be ignored inside mamba kernels.
             if prep_initial_states:
-                chunk_indices, chunk_offsets = (
+                chunk_indices_p, chunk_offsets_p = (
                     _query_start_loc_to_chunk_indices_offsets(
                         query_start_loc_p, self.chunk_size,
                         num_prefill_tokens))
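As a quick illustration of what the repeat_interleave call above produces, here is a self-contained example with made-up sizes (two prefill requests of 3 and 2 tokens; variable names follow the code, values are illustrative):

import torch

# Cumulative token offsets of the prefill requests; diff() yields the
# per-request lengths [3, 2].
query_start_loc_p = torch.tensor([0, 3, 5], dtype=torch.int32)
num_prefills = 2
num_prefill_tokens = 5

# Each prefill token is labeled with the index of the request it belongs to.
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)

print(seq_idx_p)  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)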
@@ -173,12 +178,12 @@ def build(self,
             num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
-            has_initial_states=has_initial_states,
             prep_initial_states=prep_initial_states,
             chunk_size=self.chunk_size,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
+            has_initial_states_p=has_initial_states_p,
+            seq_idx_p=seq_idx_p,
+            chunk_indices_p=chunk_indices_p,
+            chunk_offsets_p=chunk_offsets_p,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
