
Commit 534c45b

Improve fast_topk function with type hints and documentation (#22530)
Signed-off-by: zitian.zhao <[email protected]>
1 parent 3d7363e commit 534c45b

1 file changed: +17 −1 lines changed

vllm/model_executor/models/utils.py

Lines changed: 17 additions & 1 deletion
@@ -736,7 +736,23 @@ def cast_overflow_tensors(
     return tensors
 
 
-def fast_topk(values, topk, dim):
+def fast_topk(values: torch.Tensor, topk: int,
+              dim: int) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Optimized topk implementation that uses torch.max for k=1 case.
+
+    This function provides better performance for the common case of k=1
+    by using torch.max instead of the more general torch.topk.
+
+    Args:
+        values: Input tensor to find top-k values from
+        topk: Number of top values to return (k). Must be > 0.
+        dim: Dimension along which to compute topk
+
+    Returns:
+        Tuple of (values, indices) where values are the top-k values
+        and indices are their corresponding indices in the input tensor
+    """
     if topk == 1:
         # Use max along the specified dimension to get both value and index
         return torch.max(values, dim=dim, keepdim=True)
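
For orientation beyond the hunk shown above, here is a minimal, self-contained sketch of how the annotated function behaves for k=1 versus k>1. The else branch falling back to torch.topk is an assumption (it lies outside the lines shown in this diff), and the example tensor is purely illustrative.

import torch

# Sketch of the annotated function; the torch.topk fallback for k > 1 is
# assumed here and is not part of the hunk shown in this commit.
def fast_topk(values: torch.Tensor, topk: int,
              dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    if topk == 1:
        # torch.max(..., dim=...) returns a (values, indices) pair,
        # matching the layout torch.topk would return for k=1.
        return torch.max(values, dim=dim, keepdim=True)
    return torch.topk(values, topk, dim=dim)

logits = torch.tensor([[0.1, 2.5, -1.0],
                       [3.0, 0.2, 0.7]])

top1_vals, top1_idx = fast_topk(logits, topk=1, dim=-1)
# top1_vals -> tensor([[2.5], [3.0]]); top1_idx -> tensor([[1], [0]])

top2_vals, top2_idx = fast_topk(logits, topk=2, dim=-1)
# top2_vals -> tensor([[2.5, 0.1], [3.0, 0.7]])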

0 commit comments