@@ -609,12 +609,8 @@ def __init__(
609609 # compatibility
610610 if not lora_config else lora_config .lora_vocab_padding_size ,
611611 )
612- # Current step used indices
613- self .current_indices : List [int ] = []
614612 # Used to track and store by the Mamba cache between steps.
615613 self .mamba_cache : Tuple [torch .Tensor , torch .Tensor ] = tuple ()
616- # Used as an input_buffer for the CUDA graph runs.
617- self .mamba_gc_cache_buffer : Tuple [torch .Tensor , torch .Tensor ] = tuple ()
618614 # Maps between the request id and a dict that maps between the seq_id
619615 # and its index inside the self.mamba_cache
620616 self .mamba_cache_indices_mapping : Dict [str , Dict [int , int ]] = {}
@@ -644,95 +640,148 @@ def forward(self,
644640 batch_size = input_ids .shape [0 ]
645641 if attn_metadata .prefill_metadata :
646642 batch_size = len (request_ids_to_seq_ids )
647- (
648- current_seqlen_agnostic_cache ,
649- indices ,
650- ) = self ._prepare_current_run_mamba_cache (request_ids_to_seq_ids ,
651- batch_size ,
652- finished_requests_ids )
643+ mamba_cache = self ._prepare_current_run_mamba_cache (
644+ request_ids_to_seq_ids , batch_size , finished_requests_ids )
653645 else :
654646 # CUDA graph capturing runs
655- current_seqlen_agnostic_cache , indices = (
656- kwargs ["seqlen_agnostic_capture_inputs" ],
657- [],
658- )
659- self .current_indices = indices
647+ mamba_cache = kwargs ["seqlen_agnostic_capture_inputs" ]
660648
661649 hidden_states = self .model (input_ids , positions , kv_caches ,
662- attn_metadata ,
663- current_seqlen_agnostic_cache [0 ],
664- current_seqlen_agnostic_cache [1 ])
665-
666- if "seqlen_agnostic_capture_inputs" not in kwargs :
667- self ._copy_mamba_cache_by_indices (self .current_indices ,
668- current_seqlen_agnostic_cache )
669-
650+ attn_metadata , mamba_cache [0 ],
651+ mamba_cache [1 ])
670652 return hidden_states
671653
672- def _copy_mamba_cache_by_indices (
673- self , indices : List [ int ],
674- current_seqlen_agnostic_cache : Tuple [ torch . Tensor , torch . Tensor ]) :
675- for i , offset in enumerate ( indices ):
676- self . _copy_mamba_cache ( offset , i , current_seqlen_agnostic_cache )
654+ def _swap_mamba_cache ( self , from_index : int , to_index : int ):
655+ assert len ( self . mamba_cache ) > 0
656+ for cache_t in self . mamba_cache :
657+ cache_t [:, [ to_index , from_index ]] = \
658+ cache_t [:, [ from_index , to_index ]]
677659
678- def _copy_mamba_cache (self , index_to : int , index_from : int ,
679- from_buffer : Tuple [torch .Tensor , torch .Tensor ]):
660+ def _copy_mamba_cache (self , from_index : int , to_index : int ):
680661 assert len (self .mamba_cache ) > 0
681- for ( cache_t , from_buffer_t ) in zip ( self .mamba_cache , from_buffer ) :
682- cache_t [:, index_to ].copy_ (from_buffer_t [:, index_from ],
662+ for cache_t in self .mamba_cache :
663+ cache_t [:, to_index ].copy_ (cache_t [:, from_index ],
683664 non_blocking = True )
684665
685- def _assign_seq_id_to_mamba_cache (self , cur_rid : str ,
686- seqs_id : List [int ]) -> List [int ]:
687- indices_for_current_run = []
688- for seq_id in seqs_id :
689- if cur_rid not in self .mamba_cache_indices_mapping :
690- self .mamba_cache_indices_mapping [cur_rid ] = {}
691- first_free_index = self ._first_free_index_in_mamba_cache ()
692- self .mamba_cache_indices_mapping [cur_rid ][
693- seq_id ] = first_free_index
694- index_for_current_run = first_free_index
695- ## case of decoding n>1, copy prefill cache to decoding indices
696- elif seq_id not in (seq_ids2indices :=
697- self .mamba_cache_indices_mapping [cur_rid ]):
698- first_free_index = self ._first_free_index_in_mamba_cache ()
699- index_exist = list (seq_ids2indices .values ())[0 ]
700- self ._copy_mamba_cache (index_from = index_exist ,
701- index_to = first_free_index ,
702- from_buffer = self .mamba_cache )
703- self .mamba_cache_indices_mapping [cur_rid ][
704- seq_id ] = first_free_index
705- index_for_current_run = first_free_index
706- else :
707- index_for_current_run = self .mamba_cache_indices_mapping [
708- cur_rid ][seq_id ]
709-
710- indices_for_current_run .append (index_for_current_run )
711- return indices_for_current_run
666+ def _move_out_if_already_occupied (self , index : int ,
667+ all_occupied_indices : List [int ]):
668+ if index in all_occupied_indices :
669+ first_free_index = self ._first_free_index_in_mamba_cache ()
670+ # In case occupied, move the occupied to a new empty block
671+ self ._move_cache_index_and_mappings (from_index = index ,
672+ to_index = first_free_index )
673+
674+ def _assign_seq_id_to_mamba_cache_in_specific_dest (self , cur_rid : str ,
675+ seq_id : int ,
676+ destination_index : int ):
677+ """
678+ Assign (req_id,seq_id) pair to a `destination_index` index, if
679+ already occupied, move the occupying index to a free index.
680+ """
681+ all_occupied_indices = self ._get_all_occupied_indices ()
682+ if cur_rid not in self .mamba_cache_indices_mapping :
683+ self ._move_out_if_already_occupied (
684+ index = destination_index ,
685+ all_occupied_indices = all_occupied_indices )
686+ self .mamba_cache_indices_mapping [cur_rid ] = {
687+ seq_id : destination_index
688+ }
689+ elif seq_id not in (seq_ids2indices :=
690+ self .mamba_cache_indices_mapping [cur_rid ]):
691+ # parallel sampling , where n > 1, assume prefill have
692+ # already happened now we only need to copy the already
693+ # existing cache into the siblings seq_ids caches
694+ self ._move_out_if_already_occupied (
695+ index = destination_index ,
696+ all_occupied_indices = all_occupied_indices )
697+ index_exists = list (seq_ids2indices .values ())[0 ]
698+ # case of decoding n>1, copy prefill cache to decoding indices
699+ self ._copy_mamba_cache (from_index = index_exists ,
700+ to_index = destination_index )
701+ self .mamba_cache_indices_mapping [cur_rid ][
702+ seq_id ] = destination_index
703+ else :
704+ # already exists
705+ cache_index_already_exists = self .mamba_cache_indices_mapping [
706+ cur_rid ][seq_id ]
707+ if cache_index_already_exists != destination_index :
708+ # In case the seq id already exists but not in
709+ # the right destination, swap it with what's occupying it
710+ self ._swap_pair_indices_and_mappings (
711+ from_index = cache_index_already_exists ,
712+ to_index = destination_index )
712713
713714 def _prepare_current_run_mamba_cache (
714- self , request_ids_to_seq_ids : Dict [str , list [int ]], batch_size : int ,
715- finished_requests_ids : List [str ]
716- ) -> Tuple [Tuple [torch .Tensor , torch .Tensor ], List [int ]]:
717- indices_for_current_run = []
718- for request_id , seqs_id in request_ids_to_seq_ids .items ():
715+ self , request_ids_to_seq_ids : Dict [str , list [int ]],
716+ batch_size : int , finished_requests_ids : List [str ]):
717+ running_indices = []
718+ request_ids_to_seq_ids_flatten = [
719+ (req_id , seq_id )
720+ for req_id , seq_ids in request_ids_to_seq_ids .items ()
721+ for seq_id in seq_ids
722+ ]
723+ for dest_index , (request_id ,
724+ seq_id ) in enumerate (request_ids_to_seq_ids_flatten ):
719725 if request_id in finished_requests_ids :
720- # Do not allocate cache for requests that run
726+ # Do not allocate cache index for requests that run
721727 # and finish right after
722728 continue
723- indices_for_current_run += self ._assign_seq_id_to_mamba_cache (
724- request_id , seqs_id )
725- ## Pad the batch in case of running batch that was not captured via CG
726- padded_indices = indices_for_current_run .copy ()
727- pad_index = self ._first_free_index_in_mamba_cache ()
729+ self ._assign_seq_id_to_mamba_cache_in_specific_dest (
730+ request_id , seq_id , dest_index )
731+ running_indices .append (dest_index )
728732
729- for _ in range (batch_size - len (indices_for_current_run )):
730- padded_indices .append (pad_index )
733+ self ._clean_up_first_bs_blocks (batch_size , running_indices )
734+ conv_state = self .mamba_cache [0 ][:, :batch_size ]
735+ temporal_state = self .mamba_cache [1 ][:, :batch_size ]
731736
732- conv_state = self .mamba_cache [0 ][:, padded_indices ]
733- temporal_state = self .mamba_cache [1 ][:, padded_indices ]
737+ return (conv_state , temporal_state )
734738
735- return (conv_state , temporal_state ), indices_for_current_run
739+ def _get_all_occupied_indices (self ):
740+ return [
741+ cache_idx
742+ for seq_ids2indices in self .mamba_cache_indices_mapping .values ()
743+ for cache_idx in seq_ids2indices .values ()
744+ ]
745+
746+ def _clean_up_first_bs_blocks (self , batch_size : int ,
747+ indices_for_current_run : List [int ]):
748+ # move out all of the occupied but currently not running blocks
749+ # outside of the first n blocks
750+ destination_indices = set ([range (batch_size )])
751+ max_possible_batch_size = self .mamba_cache [0 ].shape [1 ]
752+ for destination_index in destination_indices :
753+ if destination_index in self ._get_all_occupied_indices () and \
754+ destination_index not in indices_for_current_run :
755+ # move not running indices outside of the batch
756+ all_other_indices = list (
757+ range (batch_size , max_possible_batch_size ))
758+ first_avail_index = self ._first_free_index_in_mamba_cache (
759+ all_other_indices )
760+ self ._swap_indices (from_index = destination_index ,
761+ to_index = first_avail_index )
762+
763+ def _move_cache_index_and_mappings (self , from_index : int , to_index : int ):
764+ self ._copy_mamba_cache (from_index = from_index , to_index = to_index )
765+ self ._update_mapping_index (from_index = from_index , to_index = to_index )
766+
767+ def _swap_pair_indices_and_mappings (self , from_index : int , to_index : int ):
768+ self ._swap_mamba_cache (from_index = from_index , to_index = to_index )
769+ self ._swap_mapping_index (from_index = from_index , to_index = to_index )
770+
771+ def _swap_mapping_index (self , from_index : int , to_index : int ):
772+ for seq_ids2index in self .mamba_cache_indices_mapping .values ():
773+ for seq_id , index in seq_ids2index .items ():
774+ if from_index == index :
775+ seq_ids2index .update ({seq_id : to_index })
776+ elif to_index == index :
777+ seq_ids2index .update ({seq_id : from_index })
778+
779+ def _update_mapping_index (self , from_index : int , to_index : int ):
780+ for seq_ids2index in self .mamba_cache_indices_mapping .values ():
781+ for seq_id , index in seq_ids2index .items ():
782+ if from_index == index :
783+ seq_ids2index .update ({seq_id : to_index })
784+ return
736785
737786 def copy_inputs_before_cuda_graphs (self , input_buffers , ** kwargs ):
738787 """
@@ -747,55 +796,35 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
747796 self ._release_mamba_cache (finished_requests_ids )
748797 request_ids_to_seq_ids = kwargs ["request_ids_to_seq_ids" ]
749798 cg_batch_size = input_buffers ['input_ids' ].shape [0 ]
750- (
751- current_mamba_cache ,
752- indices ,
753- ) = self ._prepare_current_run_mamba_cache (request_ids_to_seq_ids ,
754- cg_batch_size ,
755- finished_requests_ids )
756- self .current_indices = indices
757-
758- for input_buffer , current_cache_buffer in zip (
759- input_buffers ["seqlen_agnostic_capture_inputs" ],
760- current_mamba_cache ):
761- input_buffer .copy_ (current_cache_buffer , non_blocking = True )
762-
763- def copy_outputs_after_cuda_graphs (self , input_buffers , ** kwargs ):
764- """
765- Copy the relevant Mamba cache from the CUDA graph input_buffers
766- back to the JambaForCausalLM.mamba_cache after CUDA
767- graph replay run is done.
768- """
769- self ._copy_mamba_cache_by_indices (
770- self .current_indices ,
771- input_buffers ["seqlen_agnostic_capture_inputs" ])
799+ self ._prepare_current_run_mamba_cache (request_ids_to_seq_ids ,
800+ cg_batch_size ,
801+ finished_requests_ids )
772802
773803 def get_seqlen_agnostic_capture_inputs (self , batch_size : int ):
774804 """
775805 Provide the CUDA graph capture runs with a buffer in adjusted size.
776806 The buffer is used to maintain the Mamba Cache during the CUDA graph
777807 replay runs.
778808 """
779- return tuple (buffer [:, :batch_size ]
780- for buffer in self .mamba_gc_cache_buffer )
809+ return tuple (buffer [:, :batch_size ] for buffer in self .mamba_cache )
781810
782811 def _release_mamba_cache (self , finished_seq_groups_req_ids : List [str ]):
783812 for req_id in finished_seq_groups_req_ids :
784813 if req_id in self .mamba_cache_indices_mapping :
785814 self .mamba_cache_indices_mapping .pop (req_id )
786815
787- def _first_free_index_in_mamba_cache (self ) -> int :
788- if self .mamba_cache :
816+ def _first_free_index_in_mamba_cache (
817+ self , indices_range : Optional [List [int ]] = None ) -> int :
818+ assert self .mamba_cache is not None
819+ if indices_range is None :
789820 max_possible_batch_size = self .mamba_cache [0 ].shape [1 ]
790- occupied = [
791- id for seq_ids in self .mamba_cache_indices_mapping .values ()
792- for id in seq_ids .values ()
793- ]
794- first_free_index = [
795- i not in occupied for i in range (max_possible_batch_size )
796- ].index (True )
797- return first_free_index
798- return 0
821+ indices_range = list (range (max_possible_batch_size ))
822+ all_occupied_indices = self ._get_all_occupied_indices ()
823+ for i in indices_range :
824+ if i not in all_occupied_indices :
825+ return i
826+ raise Exception ("Couldn't find a free spot in the mamba cache! This"
827+ "should never happen" )
799828
800829 def _get_mamba_cache_shape (
801830 self
@@ -819,20 +848,18 @@ def _prepare_mamba_cache(self):
819848 [layer_type == "mamba" for layer_type in layers_type ])
820849 max_batch_size = (_get_graph_batch_size (
821850 self .scheduler_config .max_num_seqs ) if self .scheduler_config else
822- max (_BATCH_SIZES_TO_CAPTURE )) + 10
851+ max (_BATCH_SIZES_TO_CAPTURE ) + 2 )
823852 conv_state_shape , temporal_state_shape = self ._get_mamba_cache_shape ()
824853 assert conv_state_shape is not None and temporal_state_shape is not None
825854
826- for buffername in ["mamba_cache" , "mamba_gc_cache_buffer" ]:
827- buffer = (torch .empty (size = (mamba_layers , max_batch_size ) +
828- conv_state_shape ,
829- dtype = dtype ,
830- device = "cuda" ),
831- torch .empty (size = (mamba_layers , max_batch_size ) +
832- temporal_state_shape ,
833- dtype = dtype ,
834- device = "cuda" ))
835- setattr (self , buffername , buffer )
855+ self .mamba_cache = (torch .empty (size = (mamba_layers , max_batch_size ) +
856+ conv_state_shape ,
857+ dtype = dtype ,
858+ device = "cuda" ),
859+ torch .empty (size = (mamba_layers , max_batch_size ) +
860+ temporal_state_shape ,
861+ dtype = dtype ,
862+ device = "cuda" ))
836863
837864 def compute_logits (self , hidden_states : torch .Tensor ,
838865 sampling_metadata : SamplingMetadata ) -> torch .Tensor :
0 commit comments