 from packaging import version
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config.model import LogprobsMode
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
@@ -55,6 +56,17 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
                 self.forward = self.forward_native
             else:
                 self.forward = self.forward_cpu
+        elif (
+            logprobs_mode not in ("processed_logits", "processed_logprobs")
+            and rocm_aiter_ops.is_enabled()
+        ):
+            import aiter.ops.sampling  # noqa: F401
+
+            self.aiter_ops = torch.ops.aiter
+            logger.info_once(
+                "Using aiter sampler on ROCm (lazy import, sampling-only)."
+            )
+            self.forward = self.forward_hip
         else:
             self.forward = self.forward_native
 
@@ -138,6 +150,64 @@ def forward_cpu(
 
         return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
+    def forward_hip(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: torch.Tensor | None,
+        p: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Optimized ROCm/aiter path (same structure as forward_cuda)."""
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.warning_once(
+                    "aiter sampler does not support per-request generators; "
+                    "falling back to PyTorch-native."
+                )
+            return self.forward_native(logits, generators, k, p)
+        assert self.logprobs_mode not in (
+            "processed_logits",
+            "processed_logprobs",
+        ), "aiter sampler does not support returning logits/logprobs."
+        return self.aiter_sample(logits, k, p, generators), None
+
+    def aiter_sample(
+        self,
+        logits: torch.Tensor,
+        k: torch.Tensor | None,
+        p: torch.Tensor | None,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        """Sample from logits using aiter ops."""
+        use_top_k = k is not None
+        use_top_p = p is not None
+        # Joint k+p path
+        if use_top_p and use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
+                probs,
+                None,
+                *_to_tensor_scalar_tuple(k),
+                *_to_tensor_scalar_tuple(p),
+                deterministic=True,
+            )
+            return next_token_ids.view(-1)
+        # Top-p only path
+        elif use_top_p:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
+                probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
+            )
+            return next_token_ids.view(-1)
+        # Top-k only path
+        elif use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            renorm_probs = self.aiter_ops.top_k_renorm_probs(
+                probs, *_to_tensor_scalar_tuple(k)
+            )
+            return torch.multinomial(renorm_probs, num_samples=1).view(-1)
+        raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
+
 
 # Note: this is a workaround for
 # https://github.com/pytorch/pytorch/pull/151218
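The top-k-only branch in aiter_sample above renormalizes the top-k probability mass with top_k_renorm_probs and leaves the actual draw to torch.multinomial. Below is a minimal PyTorch-only sketch of the semantics that branch assumes; top_k_renorm_probs_ref is a name local to the sketch, not an aiter API, and ties at the k-th value may keep a few extra tokens.

import torch


def top_k_renorm_probs_ref(probs: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Keep each row's top-k probabilities and renormalize them to sum to 1."""
    vocab_size = probs.shape[-1]
    sorted_probs, _ = probs.sort(dim=-1, descending=True)
    # Threshold: value of the k-th largest probability in each row.
    kth = sorted_probs.gather(-1, (k.clamp(max=vocab_size) - 1).unsqueeze(-1))
    kept = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    return kept / kept.sum(dim=-1, keepdim=True)


# Batch of 2 requests with per-request k, then one multinomial draw per row.
probs = torch.softmax(torch.randn(2, 8), dim=-1)
renorm = top_k_renorm_probs_ref(probs, torch.tensor([3, 5]))
next_token_ids = torch.multinomial(renorm, num_samples=1).view(-1)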
@@ -288,3 +358,10 @@ def flashinfer_sample(
         )
 
     return next_token_ids.view(-1)
+
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
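The starred _to_tensor_scalar_tuple expansions in aiter_sample pass each of k and p to the aiter op as a positional (tensor, scalar) pair. A short usage sketch, assuming the helper above and torch are in scope as in this file:

import torch

# Per-request top-k for a batch of 3: the tensor is forwarded, and the scalar
# is a placeholder that the code assumes the kernel ignores when a tensor is present.
k_tensor, k_scalar = _to_tensor_scalar_tuple(torch.tensor([20, 40, 10]))
assert k_scalar == 0

# Uniform top-p for the whole batch: no tensor, only the scalar value.
p_tensor, p_scalar = _to_tensor_scalar_tuple(0.9)
assert p_tensor is None and p_scalar == 0.9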