Skip to content

Commit 8354033

Browse files
LucasWilkinson and pathorn
authored and committed
chunked mla
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 08b2d84 commit 8354033

File tree

2 files changed

+94
-7
lines changed

2 files changed

+94
-7
lines changed

vllm/attention/backends/mla/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -442,10 +442,6 @@ def forward(
442442
is_decode = attn_metadata.decode_metadata is not None
443443
is_prefill = attn_metadata.prefill_metadata is not None
444444

445-
if (is_decode and is_prefill):
446-
raise NotImplementedError(
447-
"chunked prefill is not supported for MLAImplBase")
448-
449445
# Restore head dim (for rotary embedding)
450446
k_pe = k_pe.unsqueeze(1)
451447
assert hasattr(attn_metadata, "input_positions")
@@ -479,7 +475,8 @@ def forward(
479475
)
480476

481477
if attn_metadata.prefill_metadata is not None:
482-
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
478+
return self._forward_prefill(q, k_c_normed, k_pe, kv_cache,
479+
attn_metadata)
483480

484481
if attn_metadata.decode_metadata is not None:
485482
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)

vllm/attention/backends/triton_mla.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from itertools import accumulate
77
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
88

9+
import triton
10+
import triton.language as tl
11+
912
from vllm.multimodal import MultiModalPlaceholderMap
1013

1114
try:
@@ -647,6 +650,63 @@ def build(self, seq_lens: List[int], query_lens: List[int],
647650
)
648651

649652

653+
@triton.jit
654+
def _gather_kv_cache(
655+
# Pointers to inputs and output
656+
seq_start_locs, # (batch_size + 1,)
657+
block_tables, # (batch_size, max_blocks_per_seq)
658+
block_table_stride,
659+
kv_cache, # (num_blocks, block_size, head_size)
660+
kv_page_stride,
661+
kv_out,
662+
CACHE_PAGE_SIZE: tl.constexpr,
663+
CACHE_ENTRY_SIZE: tl.constexpr,
664+
CACHE_ENTRIES_PER_PAGE: tl.constexpr,
665+
CACHE_PAGE_SIZE_POW_2: tl.constexpr,
666+
CACHE_ENTRY_SIZE_POW_2: tl.constexpr,
667+
):
668+
"""A Triton-accelerated function to perform per-token-group
669+
quantization on a tensor.
670+
This function converts the tensor values into float8 values.
671+
"""
672+
673+
# Map the program id to the row of X and Y it should compute.
674+
g_id = tl.program_id(0)
675+
676+
seq_start_loc = tl.load(seq_start_locs + g_id)
677+
seq_len = tl.load(seq_start_locs + g_id + 1) - seq_start_loc
678+
679+
pages_to_copy = tl.cdiv(seq_len, CACHE_ENTRIES_PER_PAGE)
680+
kv_out = kv_out + seq_start_loc * CACHE_ENTRY_SIZE
681+
block_table = block_tables + g_id * block_table_stride
682+
683+
cache_page_range = tl.arange(0, CACHE_PAGE_SIZE_POW_2)
684+
cache_page_mask = cache_page_range < CACHE_PAGE_SIZE
685+
for i in range(pages_to_copy - 1):
686+
page = tl.load(block_table + i)
687+
page_start = kv_cache + page * kv_page_stride
688+
page_data = tl.load(page_start + cache_page_range,
689+
mask=cache_page_mask)
690+
tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range,
691+
page_data,
692+
mask=cache_page_mask)
693+
694+
last_page_len = seq_len % CACHE_ENTRIES_PER_PAGE
695+
last_page = tl.load(block_table + pages_to_copy - 1)
696+
last_page_start = kv_cache + last_page * kv_page_stride
697+
698+
cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2)
699+
cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE
700+
kv_out_page = kv_out + (pages_to_copy - 1) * CACHE_PAGE_SIZE
701+
for i in range(last_page_len):
702+
last_page_data = tl.load(last_page_start + \
703+
i * CACHE_ENTRY_SIZE + cache_entry_range,
704+
mask=cache_entry_mask)
705+
tl.store(kv_out_page + i * CACHE_ENTRY_SIZE + cache_entry_range,
706+
last_page_data,
707+
mask=cache_entry_mask)
708+
709+
650710
class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
651711

652712
def __init__(
@@ -686,12 +746,42 @@ def __init__(
686746
def _forward_prefill(
687747
self,
688748
q: torch.Tensor,
689-
kv_c_normed: torch.Tensor,
749+
kv_c: torch.Tensor,
690750
k_pe: torch.Tensor,
751+
kv_c_and_k_pe_cache: torch.Tensor,
691752
attn_metadata: TritonMLAMetadata,
692753
) -> torch.Tensor:
693754
assert isinstance(attn_metadata, TritonMLAMetadata)
694-
return self._forward_prefill_flash(q, kv_c_normed, k_pe,
755+
756+
if attn_metadata.prefill_metadata.context_lens_tensor is not None and \
757+
max(attn_metadata.prefill_metadata.context_lens_tensor) > 0:
758+
entries_total = attn_metadata.prefill_metadata.seq_start_loc[-1]
759+
kv_c_k_pe_cache = torch.empty(
760+
(entries_total, kv_c_and_k_pe_cache.shape[-1]),
761+
dtype=kv_c_and_k_pe_cache.dtype,
762+
device=kv_c_and_k_pe_cache.device,
763+
)
764+
765+
assert kv_c_and_k_pe_cache.shape[-1] == 576
766+
assert kv_c_and_k_pe_cache.shape[-2] == 16
767+
_gather_kv_cache[(attn_metadata.num_prefills, )](
768+
attn_metadata.prefill_metadata.seq_start_loc,
769+
attn_metadata.prefill_metadata.block_tables,
770+
attn_metadata.prefill_metadata.block_tables.stride(0),
771+
kv_c_and_k_pe_cache,
772+
kv_c_and_k_pe_cache.stride(0),
773+
kv_c_k_pe_cache,
774+
CACHE_PAGE_SIZE=576 * 16,
775+
CACHE_ENTRY_SIZE=576,
776+
CACHE_ENTRIES_PER_PAGE=16,
777+
CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(576),
778+
CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(576 * 16),
779+
)
780+
781+
kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(1)
782+
k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(1)
783+
784+
return self._forward_prefill_flash(q, kv_c, k_pe,
695785
attn_metadata.seq_start_loc,
696786
attn_metadata.max_prefill_seq_len)
697787

0 commit comments

Comments
 (0)