1 file changed: +11 -1 lines changed
Directory: vllm/model_executor/models
Columns: original file line number | diff line number | diff line change
@@ -717,6 +717,15 @@ def compute_attn_mask_seqlen(
717717 seqlens = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).tolist ()
718718 return max_seqlen , seqlens
719719
720+ @staticmethod
721+ def invert_permutation (perm : torch .Tensor ) -> torch .Tensor :
722+ # building the inverse permutation in O(n) time
723+ inv = torch .empty_like (perm )
724+ inv [perm ] = torch .arange (perm .numel (),
725+ device = perm .device ,
726+ dtype = perm .dtype )
727+ return inv
728+
720729 def forward (
721730 self ,
722731 x : torch .Tensor ,
@@ -760,6 +769,8 @@ def forward(
760769
761770 rotary_pos_emb = torch .cat (rotary_pos_emb )
762771 window_index = torch .cat (window_index )
772+ # compute reverse indices
773+ reverse_indices = self .invert_permutation (window_index )
763774 cu_window_seqlens = torch .cat (cu_window_seqlens )
764775 cu_window_seqlens = torch .unique_consecutive (cu_window_seqlens )
765776 cu_seqlens = torch .cat (cu_seqlens )
@@ -813,7 +824,6 @@ def forward(
813824
814825 # adapter
815826 hidden_states = self .merger (hidden_states )
816- reverse_indices = torch .argsort (window_index )
817827 hidden_states = hidden_states [reverse_indices , :]
818828 return hidden_states
819829
You can’t perform that action at this time.
0 commit comments