Skip to content

Commit 778f554

Browse files
authored
[V1] [Hybrid] Some additional clean-up in Mamba2 prefix caching (#26222)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent d3c8429 commit 778f554

File tree

4 files changed

+171
-136
lines changed

4 files changed

+171
-136
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 96 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -595,21 +595,32 @@ def forward_cuda(
595595
if prefix_caching_enabled:
596596
# If prefix caching is enabled, retrieve the relevant variables
597597
# for prefill and decode
598-
last_state_idx_d, last_state_idx_p = torch.split(
599-
attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
598+
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
599+
torch.split(
600+
attn_metadata.block_idx_last_computed_token,
601+
[num_decodes, num_prefills],
602+
dim=0,
603+
)
600604
)
601-
current_last_idx_d, current_last_idx_p = torch.split(
602-
attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
605+
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
606+
torch.split(
607+
attn_metadata.block_idx_last_scheduled_token,
608+
[num_decodes, num_prefills],
609+
dim=0,
610+
)
603611
)
604612
# Prefill-only variables:
605-
current_first_idx_p = attn_metadata.current_first_idx_p
606-
context_lens_p = attn_metadata.context_lens_p
607-
last_computed_offset_p = attn_metadata.last_computed_offset_p
613+
block_idx_first_scheduled_token_p = (
614+
attn_metadata.block_idx_first_scheduled_token_p
615+
)
616+
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
608617
else:
609-
last_state_idx_d, last_state_idx_p = None, None
610-
current_last_idx_d, current_last_idx_p = None, None
611-
current_first_idx_p = None
612-
context_lens_p = None
618+
block_idx_last_computed_token_d = None
619+
block_idx_last_computed_token_p = None
620+
block_idx_last_scheduled_token_d = None
621+
block_idx_last_scheduled_token_p = None
622+
block_idx_first_scheduled_token_p = None
623+
num_computed_tokens_p = None
613624

614625
# Preallocate output tensor to avoid memcpy cost for merging prefill
615626
# and decode outputs
@@ -637,7 +648,8 @@ def forward_cuda(
637648
# to by "state_indices_tensor_p".
638649
# In particular, it will always write the state at the
639650
# sequence end.
640-
# In addition, "current_first_idx_p" and "current_last_idx_p"
651+
# In addition, if "block_idx_first_scheduled_token_p" and
652+
# "block_idx_last_scheduled_token_p"
641653
# are provided (which are pointers into
642654
# "state_indices_tensor_p"), it will write additional cache
643655
# states aligned at "block_size_to_align".
@@ -652,10 +664,10 @@ def forward_cuda(
652664
conv_states=conv_state,
653665
has_initial_state=has_initial_states_p,
654666
cache_indices=state_indices_tensor_p,
655-
current_first_idx=current_first_idx_p,
656-
current_last_idx=current_last_idx_p,
657-
initial_state_idx=last_state_idx_p,
658-
context_lens=context_lens_p,
667+
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
668+
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
669+
initial_state_idx=block_idx_last_computed_token_p,
670+
num_computed_tokens=num_computed_tokens_p,
659671
block_size_to_align=mamba_block_size,
660672
metadata=attn_metadata,
661673
query_start_loc=query_start_loc_p,
@@ -669,7 +681,7 @@ def forward_cuda(
669681
kernel_ssm_indices = state_indices_tensor_p
670682
if prefix_caching_enabled:
671683
kernel_ssm_indices = state_indices_tensor_p.gather(
672-
1, last_state_idx_p.unsqueeze(1)
684+
1, block_idx_last_computed_token_p.unsqueeze(1)
673685
).squeeze(1)
674686
initial_states = torch.where(
675687
has_initial_states_p[:, None, None, None],
@@ -703,52 +715,76 @@ def forward_cuda(
703715
)
704716

705717
if prefix_caching_enabled:
706-
# Save states for sequences with more than just the final state:
707-
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
708-
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
718+
# The chunk_stride is the number of chunks per mamba block
719+
# e.g., if mamba_block_size = 512 and chunk_size = 256,
720+
# then chunk_stride = 2
721+
chunk_stride = mamba_block_size // chunk_size
722+
723+
# Save states for sequences with more than just the final state
724+
for seq_idx in range(num_prefills):
725+
# Block index for the first scheduled token
726+
block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
727+
seq_idx
728+
]
729+
730+
# Block index for the last scheduled token
731+
block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
732+
seq_idx
733+
]
734+
735+
# Number of blocks that need to be written
736+
n_blocks_to_fill = (
737+
block_idx_last_scheduled_token - block_idx_first_scheduled_token
738+
)
739+
740+
# Skip sequences that don't have any blocks to fill
741+
if n_blocks_to_fill == 0:
742+
continue
743+
744+
# Look up the state indices
709745
cache_blocks_to_fill = state_indices_tensor_p[
710746
seq_idx,
711-
current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
712-
+ n_blocks_to_fill[seq_idx],
747+
block_idx_first_scheduled_token:block_idx_last_scheduled_token,
713748
]
714-
# chunks = [0 1 2 3 4 5 6 ...]
715-
# First aligned chunk would typically be:
716-
# mamba_block_size = 1024, chunk_size = 256
717-
# 1024 // 256 - 1 --> chunks[3]
718-
# But when last chunk wasn't block aligned:
719-
# - last_computed_offset_p[seq_idx] // chunk_size
720-
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
721-
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
722-
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
723-
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
724-
chunk_stride = mamba_block_size // chunk_size
725-
first_aligned_chunk = (
726-
torch.concat(
727-
[
728-
torch.zeros(
729-
1,
730-
dtype=last_chunk_indices_p.dtype,
731-
device=last_chunk_indices_p.device,
732-
),
733-
last_chunk_indices_p + 1,
734-
]
735-
)[seq_idx]
736-
+ chunk_stride
737-
- 1
738-
- last_computed_offset_p[seq_idx] // chunk_size
749+
750+
# First chunk index for this sequence
751+
if seq_idx == 0:
752+
first_chunk = 0
753+
else:
754+
first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
755+
756+
# First chunk that is aligned on the mamba block boundary
757+
first_aligned_chunk = first_chunk + chunk_stride - 1
758+
759+
# Calculate the number of computed tokens that were not
760+
# already cached
761+
num_unaligned_computed_tokens = (
762+
num_computed_tokens_p[seq_idx] % mamba_block_size
739763
)
764+
765+
if num_unaligned_computed_tokens > 0:
766+
# If the number of computed tokens is not block aligned,
767+
# then we need to shift the index accordingly
768+
first_aligned_chunk -= (
769+
num_unaligned_computed_tokens // chunk_size
770+
)
771+
772+
# Get states to write
740773
from_where = varlen_states[
741774
first_aligned_chunk : first_aligned_chunk
742-
+ n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
775+
+ n_blocks_to_fill * chunk_stride : chunk_stride
743776
]
777+
778+
# Write the states
744779
ssm_state[cache_blocks_to_fill] = from_where
745780

746-
# For all seqs, store the last state (Note: might be partial):
781+
# For all seqs, store the last state (note: might be partial):
747782
ssm_state[
748783
state_indices_tensor_p.gather(
749-
1, current_last_idx_p.unsqueeze(1)
784+
1, block_idx_last_scheduled_token_p.unsqueeze(1)
750785
).squeeze(1)
751786
] = varlen_states[last_chunk_indices_p]
787+
752788
else:
753789
# update ssm states
754790
# - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -759,14 +795,17 @@ def forward_cuda(
759795
if has_decode:
760796
if prefix_caching_enabled:
761797
state_indices_tensor_d_input = state_indices_tensor_d.gather(
762-
1, last_state_idx_d.unsqueeze(1)
798+
1, block_idx_last_computed_token_d.unsqueeze(1)
763799
).squeeze(1)
764800
state_indices_tensor_d_output = state_indices_tensor_d.gather(
765-
1, current_last_idx_d.unsqueeze(1)
801+
1, block_idx_last_scheduled_token_d.unsqueeze(1)
766802
).squeeze(1)
767-
# Note:
768-
# for decode always: current_first_idx_d == current_last_idx_d
769-
# at block boundaries: current_first_idx_d > last_state_idx_d
803+
# for decode:
804+
# block_idx_first_scheduled_token_d ==
805+
# block_idx_last_scheduled_token_d
806+
# at block boundaries:
807+
# block_idx_first_scheduled_token_d >
808+
# block_idx_last_computed_token_d
770809
else:
771810
# Without caching, read and write in-place to the same blocks:
772811
state_indices_tensor_d_input = state_indices_tensor_d
@@ -780,8 +819,8 @@ def forward_cuda(
780819
self.conv1d.bias,
781820
self.activation,
782821
conv_state_indices=state_indices_tensor_d,
783-
current_last_idx=current_last_idx_d,
784-
initial_state_idx=last_state_idx_d,
822+
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
823+
initial_state_idx=block_idx_last_computed_token_d,
785824
)
786825

787826
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
2727
query_start_loc_ptr,
2828
batch_ptr,
2929
token_chunk_offset_ptr,
30-
current_first_idx, # (batch,)
31-
current_last_idx, # (batch,)
30+
block_idx_first_scheduled_token, # (batch,)
31+
block_idx_last_scheduled_token, # (batch,)
3232
initial_state_idx, # (batch,)
33-
context_lens, # (batch,)
33+
num_computed_tokens, # (batch,)
3434
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
3535
# Matrix dimensions
3636
dim: tl.constexpr,
@@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
9494
# In particular, if prefix caching is enabled, the program writes additional cache states to "cache_indices_ptr"
9595

9696
# Get the length of the completed sequence so far and compute the offset.
97-
current_first_index = tl.load(current_first_idx + idx_seq)
98-
current_last_index = tl.load(current_last_idx + idx_seq)
99-
sequence_completed_index = tl.load(context_lens + idx_seq)
97+
current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq)
98+
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
99+
sequence_completed_index = tl.load(num_computed_tokens + idx_seq)
100100

101101
# Compute the offset where the first stride_block_m-aligned first full block is
102102
# Value in "token-space"
@@ -476,10 +476,10 @@ def causal_conv1d_fn(
476476
has_initial_state: Optional[torch.Tensor] = None,
477477
activation: Optional[str] = "silu",
478478
pad_slot_id: int = PAD_SLOT_ID,
479-
current_first_idx: Optional[torch.Tensor] = None,
480-
current_last_idx: Optional[torch.Tensor] = None,
479+
block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
480+
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
481481
initial_state_idx: Optional[torch.Tensor] = None,
482-
context_lens: Optional[torch.Tensor] = None,
482+
num_computed_tokens: Optional[torch.Tensor] = None,
483483
block_size_to_align=0,
484484
metadata=None,
485485
validate_data=False,
@@ -523,13 +523,13 @@ def causal_conv1d_fn(
523523
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
524524
in this case, the kernel will not process entries at
525525
indices 0 and 3
526-
current_first_idx: (batch,), dtype int32
526+
block_idx_first_scheduled_token: (batch,), dtype int32
527527
The pointer into cache_indices, where the first cache block to be filled is located.
528-
current_last_idx: (batch,), dtype int32
528+
block_idx_last_scheduled_token: (batch,), dtype int32
529529
The pointer into cache_indices, where the last cache block to be filled is located.
530530
initial_state_idx: (batch,), dtype int32
531531
The pointer into cache_indices, where the cache block containing the initial state is located.
532-
context_lens: (batch,), dtype int32
532+
num_computed_tokens: (batch,), dtype int32
533533
The number of tokens already completed for each sequence
534534
block_size_to_align: int
535535
The block size to align the cached states to
@@ -708,10 +708,10 @@ def grid(META):
708708
query_start_loc,
709709
batch_ptr,
710710
token_chunk_offset_ptr,
711-
current_first_idx,
712-
current_last_idx,
711+
block_idx_first_scheduled_token,
712+
block_idx_last_scheduled_token,
713713
initial_state_idx,
714-
context_lens,
714+
num_computed_tokens,
715715
out,
716716
# Matrix dimensions
717717
dim,
@@ -735,7 +735,7 @@ def grid(META):
735735
HAS_BIAS=bias is not None,
736736
KERNEL_WIDTH=width,
737737
SILU_ACTIVATION=activation in ["silu", "swish"],
738-
IS_APC_ENABLED=current_last_idx is not None,
738+
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
739739
USE_PAD_SLOT=pad_slot_id is not None,
740740
NP2_STATELEN=np2_statelen,
741741
# launch_cooperative_grid=True
@@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
756756
conv_state_indices_ptr,
757757
num_accepted_tokens_ptr,
758758
query_start_loc_ptr, # (batch + 1)
759-
current_last_idx, # (batch,)
759+
block_idx_last_scheduled_token, # (batch,)
760760
initial_state_idx, # (batch,)
761761
o_ptr, # (batch, dim, seqlen)
762762
# Matrix dimensions
@@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
802802
if IS_APC_ENABLED:
803803
# Get the state from the initial_state_idx
804804
conv_state_init = tl.load(initial_state_idx + idx_seq)
805-
current_last_index = tl.load(current_last_idx + idx_seq)
805+
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
806806
else:
807807
conv_state_init = 0
808808
current_last_index = 0
@@ -1078,7 +1078,7 @@ def causal_conv1d_update(
10781078
query_start_loc: Optional[torch.Tensor] = None,
10791079
max_query_len: int = -1,
10801080
pad_slot_id: int = PAD_SLOT_ID,
1081-
current_last_idx: Optional[torch.Tensor] = None,
1081+
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
10821082
initial_state_idx: Optional[torch.Tensor] = None,
10831083
validate_data=False,
10841084
):
@@ -1097,7 +1097,7 @@ def causal_conv1d_update(
10971097
If not None, the conv_state is a larger tensor along the batch dim,
10981098
and we are selecting the batch coords specified by conv_state_indices.
10991099
Useful for a continuous batching scenario.
1100-
current_last_idx: (batch,), dtype int32
1100+
block_idx_last_scheduled_token: (batch,), dtype int32
11011101
The pointer into conv_state_indices, where the last cache block to be filled is located.
11021102
initial_state_idx: (batch,), dtype int32
11031103
The pointer into conv_state_indices, where the cache block containing the initial state is located.
@@ -1201,7 +1201,7 @@ def grid(META):
12011201
conv_state_indices,
12021202
num_accepted_tokens,
12031203
query_start_loc,
1204-
current_last_idx,
1204+
block_idx_last_scheduled_token,
12051205
initial_state_idx,
12061206
out,
12071207
# Matrix dimensions
@@ -1230,7 +1230,7 @@ def grid(META):
12301230
KERNEL_WIDTH=width,
12311231
SILU_ACTIVATION=activation in ["silu", "swish"],
12321232
IS_VARLEN=query_start_loc is not None,
1233-
IS_APC_ENABLED=current_last_idx is not None,
1233+
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
12341234
IS_SPEC_DECODING=num_accepted_tokens is not None,
12351235
NP2_STATELEN=np2_statelen,
12361236
USE_PAD_SLOT=pad_slot_id is not None,

0 commit comments

Comments
 (0)