Commit 29aaba5

[Perf][MTP] Optimize reject sampler in greedy situation. (#2137)
This PR ports the optimization from PR #2002 to main and makes it cleaner.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@afa5b7c

Signed-off-by: whx-sjtu <[email protected]>
1 parent ca27400 commit 29aaba5

File tree

3 files changed (+123, -61 lines)


tests/e2e/singlecard/sample/test_rejection_sampler.py

Lines changed: 20 additions & 11 deletions
@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
-        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
+        [output_tokens[0][-1], output_tokens[1][-1]],
+        device=logits.device,
+        dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
-        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
+        [output_tokens[0][-1], output_tokens[1][-1]],
+        device=logits.device,
+        dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
-                                      device=logits.device)
+                                      device=logits.device,
+                                      dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)
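
For reference on the shape change in these tests, here is a minimal standalone sketch (not part of the test file) showing what the new int32 column-vector convention for the bonus tokens produces:

import torch

# Old convention: a flat vector of bonus tokens, shape [num_seqs].
old_style = torch.tensor([42])  # shape (1,)

# New convention: one int32 bonus token per row, shape [num_seqs, 1].
single_seq = torch.tensor([[42]], dtype=torch.int32)                 # shape (1, 1)
multi_seq = torch.tensor([42, 43], dtype=torch.int32).unsqueeze(1)   # shape (2, 1)

assert single_seq.shape == (1, 1)
assert multi_seq.shape == (2, 1)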

tests/ut/sample/test_rejection_sampler.py

Lines changed: 4 additions & 2 deletions
@@ -32,11 +32,12 @@ class TestAscendRejectionSampler(TestBase):
     def test_rejection_greedy_sample_pytorch(self):
         """Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
         batch_size = 2
-        max_spec_len = 3
+        max_spec_len = 2
         output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                       PLACEHOLDER_TOKEN_ID)

         cu_num_draft_tokens = torch.tensor([2, 4])
+        num_draft_tokens = [2, 2]
         draft_token_ids = torch.tensor([10, 11, 20, 21])
         target_argmax = torch.tensor([10, 99, 20, 22])
         bonus_token_ids = torch.tensor([[100], [200]])
@@ -49,8 +50,9 @@ def test_rejection_greedy_sample_pytorch(self):
             draft_token_ids,
             target_argmax,
             bonus_token_ids,
-            is_greedy,
+            num_draft_tokens,
             max_spec_len,
+            is_greedy,
         )

         assert output_token_ids[0, 0].item() == 10
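
For readers tracking the new argument order, a rough walk-through of what this test exercises, derived from the inputs it builds (the test itself only asserts the first token; the rest follows from the greedy rejection rule in rejection_sampler.py):

# draft_token_ids = [10, 11, 20, 21], target_argmax = [10, 99, 20, 22], two drafts per request.
# Request 0: draft 10 == target 10 (accept), draft 11 != target 99 (reject) -> row becomes [10, 99, PLACEHOLDER_TOKEN_ID].
# Request 1: draft 20 == target 20 (accept), draft 21 != target 22 (reject) -> row becomes [20, 22, PLACEHOLDER_TOKEN_ID].
# Neither request accepts every draft token, so the bonus tokens 100 and 200 are never written.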

vllm_ascend/sample/rejection_sampler.py

Lines changed: 99 additions & 48 deletions
@@ -147,16 +147,25 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        rejection_greedy_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-            # num_warps=1,
-        )
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+            rejection_greedy_sample_spec_len_1_pytorch(
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+            )
+        else:
+            rejection_greedy_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                num_draft_tokens,
+                max_spec_len,
+                is_greedy,
+            )
         if sampling_metadata.all_greedy:
             return output_token_ids
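
The new fast path is taken only when every request has exactly one draft token and sampling is all-greedy; in that case the per-token and per-request axes coincide, so accept/reject reduces to a single elementwise comparison. A minimal sketch with made-up values, mirroring rejection_greedy_sample_spec_len_1_pytorch (defined below):

import torch

draft_token_ids = torch.tensor([7, 3, 5])    # one draft token per request
target_argmax = torch.tensor([7, 9, 5])      # greedy target tokens
bonus_token_ids = torch.tensor([[70], [90], [50]])
output = torch.full((3, 2), -1)              # -1 used here as a placeholder

accept = draft_token_ids == target_argmax    # [True, False, True]
output[:, 0] = target_argmax                 # slot 0 always gets the target token
output[accept, 1] = bonus_token_ids.squeeze(1)[accept]  # bonus only on full accept
# output == [[7, 70], [9, -1], [5, 50]]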

@@ -284,47 +293,89 @@ def sample_recovered_tokens(
     return recovered_token_ids


-def rejection_greedy_sample_pytorch(
-    output_token_ids,  # [batch_size, max_spec_len + 1]
-    cu_num_draft_tokens,  # [batch_size]
-    draft_token_ids,  # [num_tokens]
-    target_argmax,  # [num_tokens]
-    bonus_token_ids,  # [batch_size]
-    is_greedy=None,  # [batch_size] or None
-    max_spec_len=None,
+def rejection_greedy_sample_spec_len_1_pytorch(
+    output_token_ids,  # [batch_size, 2]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
 ):
-    batch_size = output_token_ids.shape[0]
-
-    if is_greedy is None:
-        is_greedy = torch.ones(batch_size,
-                               dtype=torch.bool,
-                               device=output_token_ids.device)
-
-    for req_idx in range(batch_size):
-        if not is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-                target_argmax_id = target_argmax[start_idx + pos].item()
-
-                output_token_ids[req_idx, pos] = target_argmax_id
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    assert batch_size == num_tokens
+    accept_req_mask = draft_token_ids == target_argmax
+    output_token_ids[:, 0] = target_argmax
+    bonus_token_ids = bonus_token_ids.squeeze(1)
+    output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]

-                if draft_token_id != target_argmax_id:
-                    rejected = True

-        if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+def rejection_greedy_sample_pytorch(
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
+    draft_tokens_per_req,  # [batch_size], list
+    max_spec_len,
+    is_greedy=None,  # [batch_size] or None
+):
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    device = output_token_ids.device
+    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
+        device, non_blocking=True)
+    if is_greedy is None:
+        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
+
+    start_indices = cu_num_draft_tokens - draft_tokens_per_req
+    req_ids = torch.arange(batch_size, device=device)
+    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
+    token_positions = torch.arange(
+        num_tokens, device=device) - start_indices[token_req_ids]
+
+    # Find the first mismatch position of each request.
+    mismatch_global = (draft_token_ids != target_argmax)
+    if max_spec_len == 0:
+        first_mismatch_pos_per_req = torch.zeros(batch_size,
+                                                 dtype=torch.long,
+                                                 device=device)
+    else:
+        # [bs, max_spec_len]
+        pos_matrix = torch.full((batch_size, max_spec_len),
+                                -1,
+                                dtype=torch.long,
+                                device=device)
+        pos_matrix[token_req_ids, token_positions] = token_positions
+        mismatch_matrix = torch.full((batch_size, max_spec_len),
+                                     False,
+                                     dtype=torch.bool,
+                                     device=device)
+        mismatch_matrix[token_req_ids, token_positions] = mismatch_global
+        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
+                                         max_spec_len * 2)
+        first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
+        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
+        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
+            no_mismatch_mask]
+
+    # Copy matched target tokens into output.
+    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
+                             draft_tokens_per_req)
+    copy_indices = torch.arange(max_spec_len + 1,
+                                device=device).expand(batch_size, -1)
+    copy_mask = copy_indices < copy_len.unsqueeze(1)
+    greedy_mask = is_greedy.unsqueeze(1)
+    final_copy_mask = copy_mask & greedy_mask
+    global_idx = start_indices.unsqueeze(1) + copy_indices
+    output_token_ids[final_copy_mask] = target_argmax[
+        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    # Fill bonus token.
+    needs_bonus = is_greedy & (first_mismatch_pos_per_req
+                               >= draft_tokens_per_req)
+    if torch.any(needs_bonus):
+        bonus_rows = torch.where(needs_bonus)[0]
+        bonus_cols = draft_tokens_per_req[bonus_rows]
+        bonus_token_ids = bonus_token_ids.squeeze(1)
+        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]


 def rejection_random_sample_pytorch(
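
To see how the vectorized kernel handles ragged draft lengths and the bonus token, here is a small hand-worked sketch with illustrative values (it assumes the patched vllm_ascend/sample/rejection_sampler.py is importable; -1 stands in for the placeholder token id):

import torch

from vllm_ascend.sample.rejection_sampler import rejection_greedy_sample_pytorch

# Request 0 proposes 1 draft token and accepts it; request 1 proposes 3 and rejects at position 1.
cu_num_draft_tokens = torch.tensor([1, 4])
draft_token_ids = torch.tensor([10, 20, 21, 22])
target_argmax = torch.tensor([10, 20, 30, 40])
bonus_token_ids = torch.tensor([[111], [222]])
max_spec_len = 3

output = torch.full((2, max_spec_len + 1), -1)
rejection_greedy_sample_pytorch(
    output,
    cu_num_draft_tokens,
    draft_token_ids,
    target_argmax,
    bonus_token_ids,
    [1, 3],                       # draft tokens per request (Python list, as in the caller above)
    max_spec_len,
    torch.tensor([True, True]),   # is_greedy
)
# Request 0 accepted all of its drafts, so it also receives its bonus token:
#   output[0] == [10, 111, -1, -1]
# Request 1 copies target tokens up to and including the first mismatch, with no bonus:
#   output[1] == [20, 30, -1, -1]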
