@@ -663,30 +663,35 @@ def _copy_mamba_cache(self, to_index: int, from_index: int):
663663 cache_t [:, to_index ].copy_ (cache_t [:, from_index ],
664664 non_blocking = True )
665665
666+ def _move_out_if_already_occupied (self , index : int ,
667+ all_occupied_indices : List [int ]):
668+ if index in all_occupied_indices :
669+ first_free_index = self ._first_free_index_in_mamba_cache ()
670+ # In case occupied, move the occupied to a new empty block
671+ self ._move_cache_index_and_mappings (from_index = index ,
672+ to_index = first_free_index )
673+
666674 def _assign_seq_id_to_mamba_cache_in_specific_dest (self , cur_rid : str ,
667675 seq_id : int ,
668676 destination_index : int ):
669677 all_occupied_indices = self ._get_all_occupied_indices ()
670678 if cur_rid not in self .mamba_cache_indices_mapping :
671- ## assign new free index
672- if destination_index in all_occupied_indices :
673- # In case occupied, move the occupied to a new empty block
674- self ._move_cache_index_and_mappings (
675- from_index = destination_index ,
676- to_index = self ._first_free_index_in_mamba_cache ())
679+ self ._move_out_if_already_occupied (
680+ index = destination_index ,
681+ all_occupied_indices = all_occupied_indices )
677682 self .mamba_cache_indices_mapping [cur_rid ] = {
678683 seq_id : destination_index
679684 }
680685 elif seq_id not in (seq_ids2indices :=
681686 self .mamba_cache_indices_mapping [cur_rid ]):
682- # N > 1
683- first_free_index = self . _first_free_index_in_mamba_cache ()
684- if destination_index in all_occupied_indices :
685- # In case occupied, move the occupied to a new empty block
686- self . _move_cache_index_and_mappings (
687- from_index = destination_index , to_index = first_free_index )
687+ # parallel sampling, where n > 1: assume prefill already happened,
688+ # so now we only need to copy the already existing cache into the
689+ # sibling seq_ids' caches
690+ self . _move_out_if_already_occupied (
691+ index = destination_index ,
692+ all_occupied_indices = all_occupied_indices )
688693 index_exists = list (seq_ids2indices .values ())[0 ]
689- ## case of decoding n>1, copy prefill cache to decoding indices
694+ # case of decoding n>1, copy prefill cache to decoding indices
690695 self ._copy_mamba_cache (from_index = index_exists ,
691696 to_index = destination_index )
692697 self .mamba_cache_indices_mapping [cur_rid ][
0 commit comments