
Commit 65571ba

yessen-deepinfra authored and pathorn committed
Implement chunked prefill for MLA
Fix alignment and modulo error in remainder for _gather_kv_cache
Only use block_tables on chunked prefill
Allow enabling chunked prefill

Signed-off-by: Patrick Reiter Horn <patrick.horn@gmail.com>
Signed-off-by: Yessen Kanapin <yessen@deepinfra.com>
1 parent 8354033 commit 65571ba

File tree: 3 files changed, +84 -51 lines

vllm/attention/backends/mla/utils.py

Lines changed: 36 additions & 18 deletions
@@ -409,8 +409,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
     def _forward_prefill(
         self,
         q: torch.Tensor,
-        kv_c_normed: torch.Tensor,
+        kv_c: torch.Tensor,
         k_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: T,
     ) -> torch.Tensor:
         raise NotImplementedError
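Note (not part of the diff): the extra kv_c_and_k_pe_cache argument exists because, with chunked prefill, queries from the current chunk must also attend to KV entries written by earlier chunks, and those live only in the paged cache. A minimal sketch of that idea with hypothetical sizes (576 matches the latent-plus-rope width asserted later in the diff):

import torch

cached_kv = torch.randn(12, 576)    # latent + rope rows gathered from earlier chunks
new_chunk_kv = torch.randn(4, 576)  # rows produced by the chunk being prefilled now
full_context_kv = torch.cat([cached_kv, new_chunk_kv], dim=0)
assert full_context_kv.shape[0] == 16  # prefill attention must span the full context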
@@ -446,22 +447,25 @@ def forward(
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
 
+        num_prefill_tokens: int = attn_metadata.num_prefill_tokens
+
         if is_decode:
-            q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
-            q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
+            decode_q_nope = self._q_proj_and_k_up_proj(
+                hidden_states_or_q_c[num_prefill_tokens:])
+            decode_q_pe = torch.matmul(hidden_states_or_q_c[num_prefill_tokens:], self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
-                                         k_pe)
-        else:
-            assert is_prefill
-            q = self.q_proj(hidden_states_or_q_c)[0]\
+            decode_q_pe, k_pe[num_prefill_tokens:] = \
+                self.rotary_emb(attn_metadata.input_positions[num_prefill_tokens:],
+                                decode_q_pe, k_pe[num_prefill_tokens:])
+        if is_prefill:
+            prefill_q = self.q_proj(hidden_states_or_q_c[:num_prefill_tokens])[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
 
             # TODO(lucas): there must be a nicer way to write this line
-            q[..., self.qk_nope_head_dim:], k_pe = \
+            prefill_q[..., self.qk_nope_head_dim:], k_pe[:num_prefill_tokens] = \
                 self.rotary_emb(
-                    attn_metadata.input_positions,
-                    q[..., self.qk_nope_head_dim:], k_pe)
+                    attn_metadata.input_positions[:num_prefill_tokens],
+                    prefill_q[..., self.qk_nope_head_dim:], k_pe[:num_prefill_tokens])
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
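Note (not part of the diff): the slicing above relies on prefill tokens occupying the first num_prefill_tokens rows of the batch, with decode tokens packed after them. A minimal sketch of that layout with hypothetical sizes:

import torch

num_prefill_tokens, num_decode_tokens, hidden = 6, 2, 4
hidden_states = torch.randn(num_prefill_tokens + num_decode_tokens, hidden)

prefill_states = hidden_states[:num_prefill_tokens]  # chunked-prefill rows
decode_states = hidden_states[num_prefill_tokens:]   # one row per decoding sequence
assert prefill_states.shape[0] + decode_states.shape[0] == hidden_states.shape[0]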
@@ -473,13 +477,25 @@ def forward(
                 kv_cache_dtype=self.kv_cache_dtype,
                 scale=layer._k_scale,
             )
+        output = torch.empty(attn_metadata.num_prefill_tokens +
+                             attn_metadata.num_decode_tokens,
+                             self.o_proj.output_size,
+                             device=hidden_states_or_q_c.device,
+                             dtype=hidden_states_or_q_c.dtype)
+        # output shape: [2048, 16, 512]
+
+        if is_prefill:
+            # forward prefill output shape: [2048, 7168]
+            output[:num_prefill_tokens] = self._forward_prefill(
+                prefill_q, k_c_normed[:num_prefill_tokens].contiguous(),
+                k_pe[:num_prefill_tokens].contiguous(), kv_cache,
+                attn_metadata)
 
-        if attn_metadata.prefill_metadata is not None:
-            return self._forward_prefill(q, k_c_normed, k_pe, kv_cache,
-                                         attn_metadata)
+        if is_decode:
+            output[num_prefill_tokens:] = self._forward_decode(
+                decode_q_nope, decode_q_pe, kv_cache, attn_metadata)
 
-        if attn_metadata.decode_metadata is not None:
-            return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
+        return output
 
     # Optional common flash-attn based prefill
     def _forward_prefill_flash(
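Note (not part of the diff): rather than returning early from whichever branch runs, the forward pass now fills one preallocated output tensor and returns it, so a batch containing both prefill and decode tokens yields a single contiguous result. A small sketch of the same assembly pattern with hypothetical shapes:

import torch

def combine_outputs(prefill_out: torch.Tensor, decode_out: torch.Tensor) -> torch.Tensor:
    # Assemble one contiguous output for a mixed batch: prefill rows first, decode rows after.
    out = torch.empty(prefill_out.shape[0] + decode_out.shape[0],
                      prefill_out.shape[1], dtype=prefill_out.dtype)
    out[:prefill_out.shape[0]] = prefill_out
    out[prefill_out.shape[0]:] = decode_out
    return out

combined = combine_outputs(torch.randn(6, 8), torch.randn(2, 8))
assert combined.shape == (8, 8)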
@@ -489,6 +505,8 @@ def _forward_prefill_flash(
         k_pe: torch.Tensor,
         seq_start_loc: torch.Tensor,
         max_prefill_seq_len: int,
+        query_start_loc: torch.Tensor,
+        max_query_len: int,
     ) -> torch.Tensor:
 
         kv_nope = self.kv_b_proj(k_c_normed)[0]\
@@ -507,9 +525,9 @@ def _forward_prefill_flash(
             q=q,
             k=k,
             v=v_padded,
-            cu_seqlens_q=seq_start_loc,
+            cu_seqlens_q=query_start_loc,
             cu_seqlens_k=seq_start_loc,
-            max_seqlen_q=max_prefill_seq_len,
+            max_seqlen_q=max_query_len,
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
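Note (not part of the diff): with chunked prefill, the query side of the varlen attention call covers only the new chunk of each sequence while the key/value side covers the full context so far, which is why cu_seqlens_q and max_seqlen_q now come from query_start_loc and max_query_len rather than the sequence lengths. A sketch with two hypothetical sequences:

import torch

# Two sequences with 5 and 7 tokens of total context, of which only the last
# 2 and 3 tokens respectively are being prefilled in this chunk.
query_lens = torch.tensor([2, 3], dtype=torch.int32)
seq_lens = torch.tensor([5, 7], dtype=torch.int32)

query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)  # tensor([0, 2, 5])
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
seq_start_loc[1:] = torch.cumsum(seq_lens, dim=0)      # tensor([0, 5, 12])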

vllm/attention/backends/triton_mla.py

Lines changed: 48 additions & 23 deletions
@@ -657,7 +657,6 @@ def _gather_kv_cache(
     block_tables,  # (batch_size, max_blocks_per_seq)
     block_table_stride,
     kv_cache,  # (num_blocks, block_size, head_size)
-    kv_page_stride,
    kv_out,
     CACHE_PAGE_SIZE: tl.constexpr,
     CACHE_ENTRY_SIZE: tl.constexpr,
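Note (not part of the diff): kv_page_stride can be dropped because the call site now passes the cache's page stride directly as CACHE_PAGE_SIZE (see the _forward_prefill hunk below), so the kernel only needs one notion of how far apart pages sit. A pure-PyTorch sketch of what the gather produces for one sequence; the function name and shapes are illustrative only:

import torch

def gather_kv_reference(kv_cache: torch.Tensor, block_table: torch.Tensor,
                        seq_len: int) -> torch.Tensor:
    # Reference for what _gather_kv_cache produces per sequence: copy each
    # allocated page in order, then keep only the first seq_len entries.
    # kv_cache: (num_blocks, block_size, entry_size); block_table: (max_blocks,)
    block_size = kv_cache.shape[1]
    num_pages = (seq_len + block_size - 1) // block_size
    pages = kv_cache[block_table[:num_pages].long()]  # (num_pages, block_size, entry_size)
    return pages.reshape(-1, kv_cache.shape[-1])[:seq_len]

cache = torch.randn(8, 16, 576)  # hypothetical paged cache
out = gather_kv_reference(cache, torch.tensor([3, 5, 1]), seq_len=35)
assert out.shape == (35, 576)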
@@ -684,16 +683,17 @@ def _gather_kv_cache(
     cache_page_mask = cache_page_range < CACHE_PAGE_SIZE
     for i in range(pages_to_copy - 1):
         page = tl.load(block_table + i)
-        page_start = kv_cache + page * kv_page_stride
+        page_start = kv_cache + page * CACHE_PAGE_SIZE
         page_data = tl.load(page_start + cache_page_range,
                             mask=cache_page_mask)
         tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range,
                  page_data,
                  mask=cache_page_mask)
 
-    last_page_len = seq_len % CACHE_ENTRIES_PER_PAGE
+    last_page_len = (seq_len + CACHE_ENTRIES_PER_PAGE -
+                     1) % CACHE_ENTRIES_PER_PAGE + 1
     last_page = tl.load(block_table + pages_to_copy - 1)
-    last_page_start = kv_cache + last_page * kv_page_stride
+    last_page_start = kv_cache + last_page * CACHE_PAGE_SIZE
 
     cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2)
     cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE
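Note (not part of the diff): the last_page_len change fixes the case where a sequence exactly fills its final page. Plain seq_len % CACHE_ENTRIES_PER_PAGE evaluates to 0 there, so the tail copy would move nothing; the adjusted expression maps that case to a full page while leaving partially filled pages unchanged. A quick check:

per_page = 16  # CACHE_ENTRIES_PER_PAGE
for seq_len in (1, 15, 16, 17, 32):
    old = seq_len % per_page
    new = (seq_len + per_page - 1) % per_page + 1
    print(seq_len, old, new)
# seq_len=16 -> old=0 (wrong), new=16; seq_len=17 -> old=1, new=1; seq_len=32 -> old=0, new=16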
@@ -753,37 +753,62 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert isinstance(attn_metadata, TritonMLAMetadata)
 
-        if attn_metadata.prefill_metadata.context_lens_tensor is not None and \
-                max(attn_metadata.prefill_metadata.context_lens_tensor) > 0:
-            entries_total = attn_metadata.prefill_metadata.seq_start_loc[-1]
-            kv_c_k_pe_cache = torch.empty(
+        prefill_meta = attn_metadata.prefill_metadata
+        assert prefill_meta is not None
+
+        if kv_c_and_k_pe_cache.numel() > 0 and \
+                prefill_meta.block_tables is not None and \
+                prefill_meta.block_tables.numel() > 0:
+            assert prefill_meta.seq_start_loc is not None
+            assert prefill_meta.max_query_len is not None
+
+            entries_total = prefill_meta.seq_start_loc[-1]
+            kv_c_k_pe_cache = torch.empty_strided(
                 (entries_total, kv_c_and_k_pe_cache.shape[-1]),
+                (kv_c_and_k_pe_cache.stride(1), 1),
                 dtype=kv_c_and_k_pe_cache.dtype,
                 device=kv_c_and_k_pe_cache.device,
             )
 
             assert kv_c_and_k_pe_cache.shape[-1] == 576
             assert kv_c_and_k_pe_cache.shape[-2] == 16
             _gather_kv_cache[(attn_metadata.num_prefills, )](
-                attn_metadata.prefill_metadata.seq_start_loc,
-                attn_metadata.prefill_metadata.block_tables,
-                attn_metadata.prefill_metadata.block_tables.stride(0),
+                prefill_meta.seq_start_loc,
+                prefill_meta.block_tables,
+                prefill_meta.block_tables.stride(0),
                 kv_c_and_k_pe_cache,
-                kv_c_and_k_pe_cache.stride(0),
                 kv_c_k_pe_cache,
-                CACHE_PAGE_SIZE=576 * 16,
-                CACHE_ENTRY_SIZE=576,
-                CACHE_ENTRIES_PER_PAGE=16,
-                CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(576),
-                CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(576 * 16),
+                CACHE_PAGE_SIZE=kv_c_and_k_pe_cache.stride(0),
+                CACHE_ENTRY_SIZE=kv_c_and_k_pe_cache.stride(1),
+                CACHE_ENTRIES_PER_PAGE=kv_c_and_k_pe_cache.shape[1],
+                CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(
+                    kv_c_and_k_pe_cache.stride(1)),
+                CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(
+                    kv_c_and_k_pe_cache.stride(0)),
             )
 
-        kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(1)
-        k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(1)
-
-        return self._forward_prefill_flash(q, kv_c, k_pe,
-                                           attn_metadata.seq_start_loc,
-                                           attn_metadata.max_prefill_seq_len)
+            kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(
+                1).contiguous()
+            k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(
+                1).contiguous()
+
+            return self._forward_prefill_flash(
+                q,
+                kv_c,
+                k_pe,
+                seq_start_loc=prefill_meta.seq_start_loc,
+                max_prefill_seq_len=prefill_meta.max_prefill_seq_len,
+                query_start_loc=prefill_meta.query_start_loc,
+                max_query_len=prefill_meta.max_query_len)
+        else:
+            return self._forward_prefill_flash(
+                q,
+                kv_c,
+                k_pe,
+                seq_start_loc=prefill_meta.seq_start_loc,
+                max_prefill_seq_len=prefill_meta.max_prefill_seq_len,
+                query_start_loc=prefill_meta.seq_start_loc,
+                max_query_len=prefill_meta.max_prefill_seq_len)
 
     def _forward_decode(
         self,
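Note (not part of the diff): torch.empty_strided is used so the gathered buffer's row stride matches the cache's per-entry stride, keeping the kernel's flat offsets consistent between source and destination; the .contiguous() calls afterwards then hand flash attention densely packed kv_c and k_pe slices. A small sketch with hypothetical sizes:

import torch

entries_total, entry_size, entry_stride = 35, 576, 576  # hypothetical; stride may exceed size
buf = torch.empty_strided((entries_total, entry_size), (entry_stride, 1),
                          dtype=torch.float16)
assert buf.stride() == (entry_stride, 1)

kv_lora_rank = 512
kv_c = buf[..., :kv_lora_rank].unsqueeze(1).contiguous()  # (entries, 1, 512)
k_pe = buf[..., kv_lora_rank:].unsqueeze(1).contiguous()  # (entries, 1, 64)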

vllm/config.py

Lines changed: 0 additions & 10 deletions
@@ -3264,16 +3264,6 @@ def __post_init__(self):
 
         current_platform.check_and_update_config(self)
 
-        # If MLA is enabled, force disable chunked prefill and prefix caching
-        if self.model_config and self.model_config.use_mla:
-            logger.info("MLA is enabled; forcing chunked prefill and prefix "
-                        "caching to be disabled.")
-            self.scheduler_config.enable_chunked_prefill = False
-            self.scheduler_config.chunked_prefill_enabled = False
-
-            if self.cache_config is not None:
-                self.cache_config.enable_prefix_caching = False
-
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
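Note (not part of the diff): with the forced-disable removed from config post-init, chunked prefill for an MLA model is simply whatever the user requests. A hedged usage sketch; the model name is only an example and enable_chunked_prefill is the pre-existing engine argument, not something added by this commit:

from vllm import LLM

# Chunked prefill can now be requested explicitly for an MLA model instead of
# being overridden to False during config __post_init__.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          enable_chunked_prefill=True,
          trust_remote_code=True)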
