Skip to content

Commit 911c8eb

Browse files
authored
[Minor][Spec Decode] Remove compiled_softmax (#15416)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent ebcebee commit 911c8eb

File tree

2 files changed

+1
-33
lines changed

2 files changed

+1
-33
lines changed

vllm/v1/sample/ops/utils.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

vllm/v1/sample/rejection_sampler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from vllm.logger import init_logger
1010
from vllm.v1.sample.metadata import SamplingMetadata
1111
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
12-
from vllm.v1.sample.ops.utils import compiled_softmax
1312
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
1413

1514
logger = init_logger(__name__)
@@ -275,8 +274,7 @@ def compute_probs(
275274
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
276275
# which is slow for large vocab sizes. This may cause performance issues.
277276
logits = apply_top_k_top_p(logits, top_k, top_p)
278-
279-
output_prob = compiled_softmax(logits)
277+
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
280278
return output_prob
281279

282280

0 commit comments

Comments
 (0)