
Commit 0af3d4f

[FEAT] [AITER] [ROCm] integrate aiter sampling ops (#26084)
Signed-off-by: vllmellm <[email protected]>
1 parent: da8dadf


vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 77 additions & 0 deletions
@@ -7,6 +7,7 @@
 from packaging import version
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config.model import LogprobsMode
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
@@ -55,6 +56,17 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
                 self.forward = self.forward_native
             else:
                 self.forward = self.forward_cpu
+        elif (
+            logprobs_mode not in ("processed_logits", "processed_logprobs")
+            and rocm_aiter_ops.is_enabled()
+        ):
+            import aiter.ops.sampling  # noqa: F401
+
+            self.aiter_ops = torch.ops.aiter
+            logger.info_once(
+                "Using aiter sampler on ROCm (lazy import, sampling-only)."
+            )
+            self.forward = self.forward_hip
         else:
             self.forward = self.forward_native
 
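A note on the new dispatch: the aiter forward is selected only when the logprobs mode does not request processed logits/logprobs and rocm_aiter_ops reports the ops enabled; every other configuration keeps the existing CPU or PyTorch-native forward. A minimal sketch of that gating, with a plain boolean standing in for rocm_aiter_ops.is_enabled() (pick_forward is an illustrative name, not part of this commit):

    # Sketch of the constructor's new dispatch; aiter_enabled stands in
    # for rocm_aiter_ops.is_enabled().
    def pick_forward(logprobs_mode: str, aiter_enabled: bool) -> str:
        # The aiter kernels cannot return processed logits/logprobs
        # (see the assert in forward_hip), so those modes fall through.
        if logprobs_mode not in ("processed_logits", "processed_logprobs") and aiter_enabled:
            return "forward_hip"
        return "forward_native"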

@@ -138,6 +150,64 @@ def forward_cpu(
 
         return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
+    def forward_hip(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: torch.Tensor | None,
+        p: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Optimized ROCm/aiter path (same structure as forward_cuda)."""
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.warning_once(
+                    "aiter sampler does not support per-request generators; "
+                    "falling back to PyTorch-native."
+                )
+            return self.forward_native(logits, generators, k, p)
+        assert self.logprobs_mode not in (
+            "processed_logits",
+            "processed_logprobs",
+        ), "aiter sampler does not support returning logits/logprobs."
+        return self.aiter_sample(logits, k, p, generators), None
+
+    def aiter_sample(
+        self,
+        logits: torch.Tensor,
+        k: torch.Tensor | None,
+        p: torch.Tensor | None,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        """Sample from logits using aiter ops."""
+        use_top_k = k is not None
+        use_top_p = p is not None
+        # Joint k+p path
+        if use_top_p and use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
+                probs,
+                None,
+                *_to_tensor_scalar_tuple(k),
+                *_to_tensor_scalar_tuple(p),
+                deterministic=True,
+            )
+            return next_token_ids.view(-1)
+        # Top-p only path
+        elif use_top_p:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
+                probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
+            )
+            return next_token_ids.view(-1)
+        # Top-k only path
+        elif use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            renorm_probs = self.aiter_ops.top_k_renorm_probs(
+                probs, *_to_tensor_scalar_tuple(k)
+            )
+            return torch.multinomial(renorm_probs, num_samples=1).view(-1)
+        raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
+
 
 # Note: this is a workaround for
 # https://github.com/pytorch/pytorch/pull/151218
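For the top-k-only branch above, aiter's top_k_renorm_probs keeps each row's k largest probabilities and renormalizes them; sampling then happens with ordinary torch.multinomial. A plain-PyTorch reference of those semantics, assuming a single scalar k rather than the (tensor, scalar) encoding the kernel actually receives:

    import torch

    # Reference semantics of the top-k-only path: zero everything below
    # the k-th largest probability per row, renormalize, then sample.
    def top_k_renorm_sample(probs: torch.Tensor, k: int) -> torch.Tensor:
        topk_vals, topk_idx = probs.topk(k, dim=-1)
        masked = torch.zeros_like(probs).scatter_(-1, topk_idx, topk_vals)
        renorm = masked / masked.sum(dim=-1, keepdim=True)
        return torch.multinomial(renorm, num_samples=1).view(-1)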
@@ -288,3 +358,10 @@ def flashinfer_sample(
     )
 
     return next_token_ids.view(-1)
+
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
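The new module-level helper supports the *-splats at the aiter call sites: each threshold is passed to the kernels as a (tensor, scalar) pair, either a per-request tensor with the scalar slot unused or no tensor with one uniform scalar. A quick illustration (batch values are made up):

    k = torch.tensor([40, 20])    # per-request top-k for a batch of 2
    _to_tensor_scalar_tuple(k)    # -> (k, 0): tensor form, scalar unused
    _to_tensor_scalar_tuple(0.9)  # -> (None, 0.9): uniform scalar form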
