|
6 | 6 | from itertools import accumulate |
7 | 7 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type |
8 | 8 |
|
| 9 | +import triton |
| 10 | +import triton.language as tl |
| 11 | + |
9 | 12 | from vllm.multimodal import MultiModalPlaceholderMap |
10 | 13 |
|
11 | 14 | try: |
@@ -647,6 +650,63 @@ def build(self, seq_lens: List[int], query_lens: List[int], |
647 | 650 | ) |
648 | 651 |
|
649 | 652 |
|
@triton.jit
def _gather_kv_cache(
    # Pointers to inputs and output
    seq_start_locs,  # (batch_size + 1,) cumulative entry offsets into kv_out
    block_tables,  # (batch_size, max_blocks_per_seq)
    block_table_stride,
    kv_cache,  # (num_blocks, block_size, head_size)
    kv_page_stride,
    kv_out,  # (total_entries, head_size) contiguous output buffer
    CACHE_PAGE_SIZE: tl.constexpr,  # elements per cache page
    CACHE_ENTRY_SIZE: tl.constexpr,  # elements per token entry
    CACHE_ENTRIES_PER_PAGE: tl.constexpr,
    CACHE_PAGE_SIZE_POW_2: tl.constexpr,
    CACHE_ENTRY_SIZE_POW_2: tl.constexpr,
):
    """Gather the paged KV cache of one sequence into a contiguous buffer.

    One program instance handles one sequence: it walks the sequence's
    block table, copies every full page of the paged ``kv_cache`` into
    ``kv_out``, then copies the trailing (possibly partial) page entry
    by entry.

    NOTE(review): assumes each page's data is contiguous in memory
    (i.e. ``kv_page_stride`` covers exactly ``CACHE_PAGE_SIZE``
    contiguous elements) — confirm against the cache layout.
    """
    # Map the program id to the sequence it should gather.
    g_id = tl.program_id(0)

    seq_start_loc = tl.load(seq_start_locs + g_id)
    seq_len = tl.load(seq_start_locs + g_id + 1) - seq_start_loc

    # Nothing to do for an empty sequence; bail out before touching the
    # block table (would otherwise index it at -1 below).
    if seq_len == 0:
        return

    pages_to_copy = tl.cdiv(seq_len, CACHE_ENTRIES_PER_PAGE)
    kv_out = kv_out + seq_start_loc * CACHE_ENTRY_SIZE
    block_table = block_tables + g_id * block_table_stride

    # Copy all pages but the last one as whole pages.
    cache_page_range = tl.arange(0, CACHE_PAGE_SIZE_POW_2)
    cache_page_mask = cache_page_range < CACHE_PAGE_SIZE
    for i in range(pages_to_copy - 1):
        page = tl.load(block_table + i)
        page_start = kv_cache + page * kv_page_stride
        page_data = tl.load(page_start + cache_page_range,
                            mask=cache_page_mask)
        tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range,
                 page_data,
                 mask=cache_page_mask)

    # Copy the trailing page entry by entry. Computed as a difference
    # rather than `seq_len % CACHE_ENTRIES_PER_PAGE`: when seq_len is an
    # exact multiple of the page size the modulo is 0 and the final
    # (full) page would be silently skipped.
    last_page_len = seq_len - (pages_to_copy - 1) * CACHE_ENTRIES_PER_PAGE
    last_page = tl.load(block_table + pages_to_copy - 1)
    last_page_start = kv_cache + last_page * kv_page_stride

    cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2)
    cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE
    kv_out_page = kv_out + (pages_to_copy - 1) * CACHE_PAGE_SIZE
    for i in range(last_page_len):
        last_page_data = tl.load(
            last_page_start + i * CACHE_ENTRY_SIZE + cache_entry_range,
            mask=cache_entry_mask)
        tl.store(kv_out_page + i * CACHE_ENTRY_SIZE + cache_entry_range,
                 last_page_data,
                 mask=cache_entry_mask)
| 709 | + |
650 | 710 | class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): |
651 | 711 |
|
652 | 712 | def __init__( |
@@ -686,12 +746,42 @@ def __init__( |
686 | 746 | def _forward_prefill( |
687 | 747 | self, |
688 | 748 | q: torch.Tensor, |
689 | | - kv_c_normed: torch.Tensor, |
| 749 | + kv_c: torch.Tensor, |
690 | 750 | k_pe: torch.Tensor, |
| 751 | + kv_c_and_k_pe_cache: torch.Tensor, |
691 | 752 | attn_metadata: TritonMLAMetadata, |
692 | 753 | ) -> torch.Tensor: |
693 | 754 | assert isinstance(attn_metadata, TritonMLAMetadata) |
694 | | - return self._forward_prefill_flash(q, kv_c_normed, k_pe, |
| 755 | + |
| 756 | + if attn_metadata.prefill_metadata.context_lens_tensor is not None and \ |
| 757 | + max(attn_metadata.prefill_metadata.context_lens_tensor) > 0: |
| 758 | + entries_total = attn_metadata.prefill_metadata.seq_start_loc[-1] |
| 759 | + kv_c_k_pe_cache = torch.empty( |
| 760 | + (entries_total, kv_c_and_k_pe_cache.shape[-1]), |
| 761 | + dtype=kv_c_and_k_pe_cache.dtype, |
| 762 | + device=kv_c_and_k_pe_cache.device, |
| 763 | + ) |
| 764 | + |
| 765 | + assert kv_c_and_k_pe_cache.shape[-1] == 576 |
| 766 | + assert kv_c_and_k_pe_cache.shape[-2] == 16 |
| 767 | + _gather_kv_cache[(attn_metadata.num_prefills, )]( |
| 768 | + attn_metadata.prefill_metadata.seq_start_loc, |
| 769 | + attn_metadata.prefill_metadata.block_tables, |
| 770 | + attn_metadata.prefill_metadata.block_tables.stride(0), |
| 771 | + kv_c_and_k_pe_cache, |
| 772 | + kv_c_and_k_pe_cache.stride(0), |
| 773 | + kv_c_k_pe_cache, |
| 774 | + CACHE_PAGE_SIZE=576 * 16, |
| 775 | + CACHE_ENTRY_SIZE=576, |
| 776 | + CACHE_ENTRIES_PER_PAGE=16, |
| 777 | + CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(576), |
| 778 | + CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(576 * 16), |
| 779 | + ) |
| 780 | + |
| 781 | + kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(1) |
| 782 | + k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(1) |
| 783 | + |
| 784 | + return self._forward_prefill_flash(q, kv_c, k_pe, |
695 | 785 | attn_metadata.seq_start_loc, |
696 | 786 | attn_metadata.max_prefill_seq_len) |
697 | 787 |
|
|
0 commit comments