@@ -651,20 +651,20 @@ def forward(self,
651651 mamba_cache [1 ])
652652 return hidden_states
653653
654- def _swap_mamba_cache (self , to_index : int , from_index : int ):
654+ def _swap_mamba_cache (self , from_index : int , to_index : int ):
655655 assert len (self .mamba_cache ) > 0
656656 for cache_t in self .mamba_cache :
657657 cache_t [:, [to_index ,from_index ]] = \
658658 cache_t [:, [from_index ,to_index ]]
659659
660- def _copy_mamba_cache (self , to_index : int , from_index : int ):
660+ def _copy_mamba_cache (self , from_index : int , to_index : int ):
661661 assert len (self .mamba_cache ) > 0
662662 for cache_t in self .mamba_cache :
663663 cache_t [:, to_index ].copy_ (cache_t [:, from_index ],
664664 non_blocking = True )
665665
666666 def _move_out_if_already_occupied (self , index : int ,
667- all_occupied_indices : List [int ]):
667+ all_occupied_indices : List [int ]):
668668 if index in all_occupied_indices :
669669 first_free_index = self ._first_free_index_in_mamba_cache ()
670670 # In case occupied, move the occupied to a new empty block
@@ -674,6 +674,10 @@ def _move_out_if_already_occupied(self, index: int,
674674 def _assign_seq_id_to_mamba_cache_in_specific_dest (self , cur_rid : str ,
675675 seq_id : int ,
676676 destination_index : int ):
677+ """
678+ Assign the (req_id, seq_id) pair to `destination_index`; if that index
679+ is already occupied, move the occupying entry to a free index.
680+ """
677681 all_occupied_indices = self ._get_all_occupied_indices ()
678682 if cur_rid not in self .mamba_cache_indices_mapping :
679683 self ._move_out_if_already_occupied (
@@ -697,7 +701,7 @@ def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
697701 self .mamba_cache_indices_mapping [cur_rid ][
698702 seq_id ] = destination_index
699703 else :
700- ## already exists
704+ # already exists
701705 cache_index_already_exists = self .mamba_cache_indices_mapping [
702706 cur_rid ][seq_id ]
703707 if cache_index_already_exists != destination_index :
0 commit comments