@@ -595,21 +595,32 @@ def forward_cuda(
         if prefix_caching_enabled:
             # If prefix caching is enabled, retrieve the relevant variables
             # for prefill and decode
-            last_state_idx_d, last_state_idx_p = torch.split(
-                attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
+            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_computed_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
             )
-            current_last_idx_d, current_last_idx_p = torch.split(
-                attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
+            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_scheduled_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
             )
             # Prefill-only variables:
-            current_first_idx_p = attn_metadata.current_first_idx_p
-            context_lens_p = attn_metadata.context_lens_p
-            last_computed_offset_p = attn_metadata.last_computed_offset_p
+            block_idx_first_scheduled_token_p = (
+                attn_metadata.block_idx_first_scheduled_token_p
+            )
+            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
         else:
-            last_state_idx_d, last_state_idx_p = None, None
-            current_last_idx_d, current_last_idx_p = None, None
-            current_first_idx_p = None
-            context_lens_p = None
+            block_idx_last_computed_token_d = None
+            block_idx_last_computed_token_p = None
+            block_idx_last_scheduled_token_d = None
+            block_idx_last_scheduled_token_p = None
+            block_idx_first_scheduled_token_p = None
+            num_computed_tokens_p = None
 
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
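The renaming above separates two roles the old names blurred: `block_idx_last_computed_token` points at the block holding the state to read (the last token already computed), while `block_idx_last_scheduled_token` points at the block whose state will be written after this step. A minimal sketch of the `torch.split` partitioning used here, with made-up batch sizes and values (only the `torch.split` call itself is taken from the diff):

```python
import torch

# Decode requests are laid out first in the batch, prefills after them,
# so one split along dim 0 yields per-phase views without copying.
num_decodes, num_prefills = 2, 3
block_idx_last_computed_token = torch.tensor([0, 1, 0, 2, 1])

d, p = torch.split(
    block_idx_last_computed_token, [num_decodes, num_prefills], dim=0
)
print(d)  # tensor([0, 1])    -> decode requests
print(p)  # tensor([0, 2, 1]) -> prefill requests
```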
@@ -637,7 +648,8 @@ def forward_cuda(
             #   to by "state_indices_tensor_p".
             #   In particular, it will always write the state at the
             #   sequence end.
-            #   In addition, "current_first_idx_p" and "current_last_idx_p"
+            #   In addition, "block_idx_first_scheduled_token_p" and
+            #   "block_idx_last_scheduled_token_p"
             #   are provided (which are pointers into
             #   "state_indices_tensor_p"), it will write additional cache
             #   states aligned at "block_size_to_align".
@@ -652,10 +664,10 @@ def forward_cuda(
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                current_first_idx=current_first_idx_p,
-                current_last_idx=current_last_idx_p,
-                initial_state_idx=last_state_idx_p,
-                context_lens=context_lens_p,
+                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+                initial_state_idx=block_idx_last_computed_token_p,
+                num_computed_tokens=num_computed_tokens_p,
                 block_size_to_align=mamba_block_size,
                 metadata=attn_metadata,
                 query_start_loc=query_start_loc_p,
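Per the comment in the previous hunk, the conv kernel always writes the state at the sequence end and, given the two block-index pointers, writes additional states at positions aligned to `block_size_to_align`. A hypothetical helper illustrating which token offsets inside the scheduled span land on such boundaries; the helper, its name, and its numbers are illustrative only, not part of the kernel:

```python
def aligned_write_positions(
    num_computed_tokens: int, num_scheduled_tokens: int, block_size: int
) -> list[int]:
    """Offsets within the scheduled span whose token completes a
    mamba block and therefore gets an extra cache write."""
    start = num_computed_tokens
    end = num_computed_tokens + num_scheduled_tokens
    # First block boundary strictly beyond the already-computed tokens.
    first = (start // block_size + 1) * block_size
    return [t - start - 1 for t in range(first, end + 1, block_size)]

# 1000 tokens already computed, 600 scheduled, 512-token blocks:
# boundaries at 1024 and 1536 fall inside the scheduled span.
print(aligned_write_positions(1000, 600, 512))  # [23, 535]
```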
@@ -669,7 +681,7 @@ def forward_cuda(
                 kernel_ssm_indices = state_indices_tensor_p
                 if prefix_caching_enabled:
                     kernel_ssm_indices = state_indices_tensor_p.gather(
-                        1, last_state_idx_p.unsqueeze(1)
+                        1, block_idx_last_computed_token_p.unsqueeze(1)
                     ).squeeze(1)
                 initial_states = torch.where(
                     has_initial_states_p[:, None, None, None],
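The `gather`/`squeeze` pattern above turns the per-sequence block index into the flat cache slot that seeds each sequence's initial state. A small self-contained illustration with invented slot numbers:

```python
import torch

# state_indices_tensor_p maps (sequence, block) -> cache slot.
state_indices_tensor_p = torch.tensor([[7, 8, 9],
                                       [4, 5, 6]])
# Per-sequence index of the block holding the last computed token.
block_idx_last_computed_token_p = torch.tensor([1, 0])

# Gather along dim 1 picks one slot per row; squeeze drops the extra dim.
slots = state_indices_tensor_p.gather(
    1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
print(slots)  # tensor([8, 4])
```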
@@ -703,52 +715,76 @@ def forward_cuda(
             )
 
             if prefix_caching_enabled:
-                # Save states for sequences with more than just the final state:
-                n_blocks_to_fill = current_last_idx_p - current_first_idx_p
-                for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
+                # The chunk_stride is the number of chunks per mamba block
+                # e.g., if mamba_block_size = 512 and chunk_size = 256,
+                # then chunk_stride = 2
+                chunk_stride = mamba_block_size // chunk_size
+
+                # Save states for sequences with more than just the final state
+                for seq_idx in range(num_prefills):
+                    # Block index for the first scheduled token
+                    block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
+                        seq_idx
+                    ]
+
+                    # Block index for the last scheduled token
+                    block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
+                        seq_idx
+                    ]
+
+                    # Number of blocks that need to be written
+                    n_blocks_to_fill = (
+                        block_idx_last_scheduled_token - block_idx_first_scheduled_token
+                    )
+
+                    # Skip sequences that don't have any blocks to fill
+                    if n_blocks_to_fill == 0:
+                        continue
+
+                    # Look up the state indices
                     cache_blocks_to_fill = state_indices_tensor_p[
                         seq_idx,
-                        current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
-                        + n_blocks_to_fill[seq_idx],
+                        block_idx_first_scheduled_token:block_idx_last_scheduled_token,
                     ]
-                    # chunks = [0 1 2 3 4 5 6 ...]
-                    # First aligned chunk would typically be:
-                    # mamba_block_size = 1024, chunk_size = 256
-                    # 1024 // 256 - 1 --> chunks[3]
-                    # But when last chunk wasn't block aligned:
-                    # - last_computed_offset_p[seq_idx] // chunk_size
-                    #   e.g. 1000 // 256 -> 3 completed --> store chunk[0]
-                    #   e.g.  513 // 256 -> 2 completed --> store chunk[1] (skip 1)
-                    #   e.g.  256 // 256 -> 1 completed --> store chunk[2] (skip 2)
-                    #   e.g.   10 // 256 -> 0 completed --> store chunk[3] (skip 3)
-                    chunk_stride = mamba_block_size // chunk_size
-                    first_aligned_chunk = (
-                        torch.concat(
-                            [
-                                torch.zeros(
-                                    1,
-                                    dtype=last_chunk_indices_p.dtype,
-                                    device=last_chunk_indices_p.device,
-                                ),
-                                last_chunk_indices_p + 1,
-                            ]
-                        )[seq_idx]
-                        + chunk_stride
-                        - 1
-                        - last_computed_offset_p[seq_idx] // chunk_size
+
+                    # First chunk index for this sequence
+                    if seq_idx == 0:
+                        first_chunk = 0
+                    else:
+                        first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
+
+                    # First chunk that is aligned on the mamba block boundary
+                    first_aligned_chunk = first_chunk + chunk_stride - 1
+
+                    # Calculate the number of computed tokens that were not
+                    # already cached
+                    num_unaligned_computed_tokens = (
+                        num_computed_tokens_p[seq_idx] % mamba_block_size
                     )
+
+                    if num_unaligned_computed_tokens > 0:
+                        # If the number of computed tokens is not block aligned,
+                        # then we need to shift the index accordingly
+                        first_aligned_chunk -= (
+                            num_unaligned_computed_tokens // chunk_size
+                        )
+
+                    # Get the states to write
                     from_where = varlen_states[
                         first_aligned_chunk : first_aligned_chunk
-                        + n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
+                        + n_blocks_to_fill * chunk_stride : chunk_stride
                     ]
+
+                    # Write the states
                     ssm_state[cache_blocks_to_fill] = from_where
 
-                # For all seqs, store the last state (Note: might be partial):
+                # For all seqs, store the last state (note: might be partial):
                 ssm_state[
                     state_indices_tensor_p.gather(
-                        1, current_last_idx_p.unsqueeze(1)
+                        1, block_idx_last_scheduled_token_p.unsqueeze(1)
                     ).squeeze(1)
                 ] = varlen_states[last_chunk_indices_p]
+
             else:
                 # update ssm states
                 # - varlen state is a (num_prefills, nheads, headdim, dstate)
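The rewritten loop above replaces the old `torch.concat`-based construction with per-sequence scalar arithmetic. A worked numeric example of the chunk indexing, under assumed sizes (`mamba_block_size = 512`, `chunk_size = 256`; the token counts are invented):

```python
mamba_block_size, chunk_size = 512, 256
chunk_stride = mamba_block_size // chunk_size  # 2 chunks per mamba block

# Suppose this sequence's chunks start at global chunk index 4
# (i.e., 1 + last_chunk_indices_p[seq_idx - 1] == 4).
first_chunk = 4
first_aligned_chunk = first_chunk + chunk_stride - 1  # 5 if block-aligned

# 812 computed tokens -> 812 % 512 = 300 tokens past the last block
# boundary were computed but are not yet block-aligned in the cache.
num_unaligned_computed_tokens = 812 % mamba_block_size  # 300
if num_unaligned_computed_tokens > 0:
    # 300 // 256 = 1 full chunk of those tokens already sits in
    # varlen_states, so the first aligned state arrives one chunk earlier.
    first_aligned_chunk -= num_unaligned_computed_tokens // chunk_size

print(first_aligned_chunk)  # 4
```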
@@ -759,14 +795,17 @@ def forward_cuda(
         if has_decode:
             if prefix_caching_enabled:
                 state_indices_tensor_d_input = state_indices_tensor_d.gather(
-                    1, last_state_idx_d.unsqueeze(1)
+                    1, block_idx_last_computed_token_d.unsqueeze(1)
                 ).squeeze(1)
                 state_indices_tensor_d_output = state_indices_tensor_d.gather(
-                    1, current_last_idx_d.unsqueeze(1)
+                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
                 ).squeeze(1)
-                # Note:
-                # for decode always: current_first_idx_d == current_last_idx_d
-                # at block boundaries: current_first_idx_d > last_state_idx_d
+                # for decode:
+                #   block_idx_first_scheduled_token_d ==
+                #   block_idx_last_scheduled_token_d
+                # at block boundaries:
+                #   block_idx_first_scheduled_token_d >
+                #   block_idx_last_computed_token_d
             else:
                 # Without caching, read and write in-place to the same blocks:
                 state_indices_tensor_d_input = state_indices_tensor_d
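The rewritten comment states two decode invariants; the arithmetic behind them can be sketched directly (block size and token counts below are invented):

```python
mamba_block_size = 512

# Decode schedules exactly one new token, so the first and last scheduled
# tokens coincide and share a block index; only when the new token starts
# a fresh mamba block does the write slot advance past the read slot.
for num_computed_tokens in (511, 512):
    block_idx_last_computed = (num_computed_tokens - 1) // mamba_block_size
    block_idx_scheduled = num_computed_tokens // mamba_block_size
    print(num_computed_tokens, block_idx_last_computed, block_idx_scheduled)

# 511 0 0  -> read and write the same block's state slot
# 512 0 1  -> block boundary: scheduled block index exceeds computed one
```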
@@ -780,8 +819,8 @@ def forward_cuda(
                 self.conv1d.bias,
                 self.activation,
                 conv_state_indices=state_indices_tensor_d,
-                current_last_idx=current_last_idx_d,
-                initial_state_idx=last_state_idx_d,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
+                initial_state_idx=block_idx_last_computed_token_d,
             )
 
             hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)