@@ -657,7 +657,6 @@ def _gather_kv_cache(
657657 block_tables , # (batch_size, max_blocks_per_seq)
658658 block_table_stride ,
659659 kv_cache , # (num_blocks, block_size, head_size)
660- kv_page_stride ,
661660 kv_out ,
662661 CACHE_PAGE_SIZE : tl .constexpr ,
663662 CACHE_ENTRY_SIZE : tl .constexpr ,
@@ -684,16 +683,17 @@ def _gather_kv_cache(
684683 cache_page_mask = cache_page_range < CACHE_PAGE_SIZE
685684 for i in range (pages_to_copy - 1 ):
686685 page = tl .load (block_table + i )
687- page_start = kv_cache + page * kv_page_stride
686+ page_start = kv_cache + page * CACHE_PAGE_SIZE
688687 page_data = tl .load (page_start + cache_page_range ,
689688 mask = cache_page_mask )
690689 tl .store (kv_out + i * CACHE_PAGE_SIZE + cache_page_range ,
691690 page_data ,
692691 mask = cache_page_mask )
693692
694- last_page_len = seq_len % CACHE_ENTRIES_PER_PAGE
693+ last_page_len = (seq_len + CACHE_ENTRIES_PER_PAGE -
694+ 1 ) % CACHE_ENTRIES_PER_PAGE + 1
695695 last_page = tl .load (block_table + pages_to_copy - 1 )
696- last_page_start = kv_cache + last_page * kv_page_stride
696+ last_page_start = kv_cache + last_page * CACHE_PAGE_SIZE
697697
698698 cache_entry_range = tl .arange (0 , CACHE_ENTRY_SIZE_POW_2 )
699699 cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE
@@ -753,37 +753,62 @@ def _forward_prefill(
753753 ) -> torch .Tensor :
754754 assert isinstance (attn_metadata , TritonMLAMetadata )
755755
756- if attn_metadata .prefill_metadata .context_lens_tensor is not None and \
757- max (attn_metadata .prefill_metadata .context_lens_tensor ) > 0 :
758- entries_total = attn_metadata .prefill_metadata .seq_start_loc [- 1 ]
759- kv_c_k_pe_cache = torch .empty (
756+ prefill_meta = attn_metadata .prefill_metadata
757+ assert prefill_meta is not None
758+
759+ if kv_c_and_k_pe_cache .numel () > 0 and \
760+ prefill_meta .block_tables is not None and \
761+ prefill_meta .block_tables .numel () > 0 :
762+ assert prefill_meta .seq_start_loc is not None
763+ assert prefill_meta .max_query_len is not None
764+
765+ entries_total = prefill_meta .seq_start_loc [- 1 ]
766+ kv_c_k_pe_cache = torch .empty_strided (
760767 (entries_total , kv_c_and_k_pe_cache .shape [- 1 ]),
768+ (kv_c_and_k_pe_cache .stride (1 ), 1 ),
761769 dtype = kv_c_and_k_pe_cache .dtype ,
762770 device = kv_c_and_k_pe_cache .device ,
763771 )
764772
765773 assert kv_c_and_k_pe_cache .shape [- 1 ] == 576
766774 assert kv_c_and_k_pe_cache .shape [- 2 ] == 16
767775 _gather_kv_cache [(attn_metadata .num_prefills , )](
768- attn_metadata . prefill_metadata .seq_start_loc ,
769- attn_metadata . prefill_metadata .block_tables ,
770- attn_metadata . prefill_metadata .block_tables .stride (0 ),
776+ prefill_meta .seq_start_loc ,
777+ prefill_meta .block_tables ,
778+ prefill_meta .block_tables .stride (0 ),
771779 kv_c_and_k_pe_cache ,
772- kv_c_and_k_pe_cache .stride (0 ),
773780 kv_c_k_pe_cache ,
774- CACHE_PAGE_SIZE = 576 * 16 ,
775- CACHE_ENTRY_SIZE = 576 ,
776- CACHE_ENTRIES_PER_PAGE = 16 ,
777- CACHE_ENTRY_SIZE_POW_2 = triton .next_power_of_2 (576 ),
778- CACHE_PAGE_SIZE_POW_2 = triton .next_power_of_2 (576 * 16 ),
781+ CACHE_PAGE_SIZE = kv_c_and_k_pe_cache .stride (0 ),
782+ CACHE_ENTRY_SIZE = kv_c_and_k_pe_cache .stride (1 ),
783+ CACHE_ENTRIES_PER_PAGE = kv_c_and_k_pe_cache .shape [1 ],
784+ CACHE_ENTRY_SIZE_POW_2 = triton .next_power_of_2 (
785+ kv_c_and_k_pe_cache .stride (1 )),
786+ CACHE_PAGE_SIZE_POW_2 = triton .next_power_of_2 (
787+ kv_c_and_k_pe_cache .stride (0 )),
779788 )
780789
781- kv_c = kv_c_k_pe_cache [..., :self .kv_lora_rank ].unsqueeze (1 )
782- k_pe = kv_c_k_pe_cache [..., self .kv_lora_rank :].unsqueeze (1 )
783-
784- return self ._forward_prefill_flash (q , kv_c , k_pe ,
785- attn_metadata .seq_start_loc ,
786- attn_metadata .max_prefill_seq_len )
790+ kv_c = kv_c_k_pe_cache [..., :self .kv_lora_rank ].unsqueeze (
791+ 1 ).contiguous ()
792+ k_pe = kv_c_k_pe_cache [..., self .kv_lora_rank :].unsqueeze (
793+ 1 ).contiguous ()
794+
795+ return self ._forward_prefill_flash (
796+ q ,
797+ kv_c ,
798+ k_pe ,
799+ seq_start_loc = prefill_meta .seq_start_loc ,
800+ max_prefill_seq_len = prefill_meta .max_prefill_seq_len ,
801+ query_start_loc = prefill_meta .query_start_loc ,
802+ max_query_len = prefill_meta .max_query_len )
803+ else :
804+ return self ._forward_prefill_flash (
805+ q ,
806+ kv_c ,
807+ k_pe ,
808+ seq_start_loc = prefill_meta .seq_start_loc ,
809+ max_prefill_seq_len = prefill_meta .max_prefill_seq_len ,
810+ query_start_loc = prefill_meta .seq_start_loc ,
811+ max_query_len = prefill_meta .max_prefill_seq_len )
787812
788813 def _forward_decode (
789814 self ,
0 commit comments