Skip to content

Commit e9af6ba

Browse files
authored
[Model Runner V2] Optimize Gumbel Sampling Kernel (#29210)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent c6fa389 commit e9af6ba

File tree

1 file changed

+43
-50
lines changed

1 file changed

+43
-50
lines changed

vllm/v1/worker/gpu/sampler.py

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from collections.abc import Callable
44

55
import torch
6-
import triton
7-
import triton.language as tl
86

97
from vllm.config.model import LogprobsMode
8+
from vllm.triton_utils import tl, triton
109
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
1110
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
1211
from vllm.v1.worker.gpu.states import SamplingMetadata
@@ -78,7 +77,10 @@ def sample(
7877

7978
@triton.jit
8079
def _gumbel_sample_kernel(
81-
sampled_ptr,
80+
local_argmax_ptr,
81+
local_argmax_stride,
82+
local_max_ptr,
83+
local_max_stride,
8284
logits_ptr,
8385
logits_stride,
8486
seeds_ptr,
@@ -88,57 +90,35 @@ def _gumbel_sample_kernel(
8890
BLOCK_SIZE: tl.constexpr,
8991
):
9092
req_idx = tl.program_id(0)
91-
is_greedy = tl.load(is_greedy_ptr + req_idx)
92-
93-
if is_greedy:
94-
# Greedy sampling. Don't apply gumbel noise.
95-
max_val = float("-inf")
96-
max_idx = 0
97-
for i in range(0, vocab_size, BLOCK_SIZE):
98-
block = i + tl.arange(0, BLOCK_SIZE)
99-
mask = block < vocab_size
100-
logits = tl.load(
101-
logits_ptr + req_idx * logits_stride + block,
102-
mask=mask,
103-
other=float("-inf"),
104-
)
105-
106-
idx = tl.argmax(logits, axis=0)
107-
value = tl.max(logits, axis=0)
108-
is_greater = value > max_val
109-
max_val = tl.where(is_greater, value, max_val)
110-
max_idx = tl.where(is_greater, i + idx, max_idx)
111-
tl.store(sampled_ptr + req_idx, max_idx)
112-
return
113-
114-
# Random sampling.
115-
# Calculate gumbel seed.
116-
seed = tl.load(seeds_ptr + req_idx)
117-
pos = tl.load(pos_ptr + req_idx)
118-
gumbel_seed = tl.randint(seed, pos)
93+
block_idx = tl.program_id(1)
94+
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
95+
mask = block < vocab_size
96+
logits = tl.load(
97+
logits_ptr + req_idx * logits_stride + block,
98+
mask=mask,
99+
other=float("-inf"),
100+
)
119101

120-
max_val = float("-inf")
121-
max_idx = 0
122-
for i in range(0, vocab_size, BLOCK_SIZE):
123-
block = i + tl.arange(0, BLOCK_SIZE)
124-
mask = block < vocab_size
102+
is_greedy = tl.load(is_greedy_ptr + req_idx)
103+
if not is_greedy:
104+
# Calculate the seed for gumbel noise.
105+
seed = tl.load(seeds_ptr + req_idx)
106+
pos = tl.load(pos_ptr + req_idx)
107+
gumbel_seed = tl.randint(seed, pos)
125108

126109
# Generate gumbel noise.
127110
r = tl.rand(gumbel_seed, block).to(tl.float64)
128111
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
129112
gumbel_noise = gumbel_noise.to(tl.float32)
130113

131114
# Apply gumbel noise.
132-
logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
133115
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
134116

135-
# Argmax to get the sampled token.
136-
idx = tl.argmax(logits, axis=0)
137-
value = tl.max(logits, axis=0)
138-
is_greater = value > max_val
139-
max_val = tl.where(is_greater, value, max_val)
140-
max_idx = tl.where(is_greater, i + idx, max_idx)
141-
tl.store(sampled_ptr + req_idx, max_idx)
117+
idx = tl.argmax(logits, axis=0)
118+
token_id = block_idx * BLOCK_SIZE + idx
119+
value = tl.max(logits, axis=0)
120+
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
121+
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
142122

143123

144124
def gumbel_sample(
@@ -148,23 +128,36 @@ def gumbel_sample(
148128
pos: torch.Tensor, # [num_reqs]
149129
) -> torch.Tensor:
150130
num_reqs, vocab_size = logits.shape
151-
# NOTE(woosuk): Use int64 for later indexing.
152-
sampled = torch.empty(
131+
BLOCK_SIZE = 1024
132+
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
133+
local_argmax = torch.empty(
153134
num_reqs,
135+
num_blocks,
154136
dtype=torch.int64,
155137
device=logits.device,
156138
)
157-
_gumbel_sample_kernel[(num_reqs,)](
158-
sampled,
139+
local_max = torch.empty(
140+
num_reqs,
141+
num_blocks,
142+
dtype=torch.float32,
143+
device=logits.device,
144+
)
145+
_gumbel_sample_kernel[(num_reqs, num_blocks)](
146+
local_argmax,
147+
local_argmax.stride(0),
148+
local_max,
149+
local_max.stride(0),
159150
logits,
160151
logits.stride(0),
161152
seed,
162153
pos,
163154
is_greedy,
164155
vocab_size,
165-
num_warps=8,
166-
BLOCK_SIZE=16384, # type: ignore
156+
BLOCK_SIZE=BLOCK_SIZE,
167157
)
158+
# NOTE(woosuk): Use int64 for later indexing.
159+
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
160+
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
168161
return sampled
169162

170163

0 commit comments

Comments (0)