File tree Expand file tree Collapse file tree 1 file changed +17
-1
lines changed
vllm/model_executor/models Expand file tree Collapse file tree 1 file changed +17
-1
lines changed Original file line number Diff line number Diff line change @@ -736,7 +736,23 @@ def cast_overflow_tensors(
736
736
return tensors
737
737
738
738
739
- def fast_topk (values , topk , dim ):
739
+ def fast_topk (values : torch .Tensor , topk : int ,
740
+ dim : int ) -> tuple [torch .Tensor , torch .Tensor ]:
741
+ """
742
+ Optimized topk implementation that uses torch.max for k=1 case.
743
+
744
+ This function provides better performance for the common case of k=1
745
+ by using torch.max instead of the more general torch.topk.
746
+
747
+ Args:
748
+ values: Input tensor to find top-k values from
749
+ topk: Number of top values to return (k). Must be > 0.
750
+ dim: Dimension along which to compute topk
751
+
752
+ Returns:
753
+ Tuple of (values, indices) where values are the top-k values
754
+ and indices are their corresponding indices in the input tensor
755
+ """
740
756
if topk == 1 :
741
757
# Use max along the specified dimension to get both value and index
742
758
return torch .max (values , dim = dim , keepdim = True )
You can’t perform that action at this time.
0 commit comments