
Commit 126cdfc

[Test] add rejection sampler ut (#2084)
### What this PR does / why we need it?
Add rejection sampler ut.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
UT passed.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@586f286

Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent f3b50c5 commit 126cdfc

File tree

1 file changed: 201 additions (+201), 0 deletions (−0)
@@ -0,0 +1,201 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.sample.rejection_sampler import (
    expand_batch_to_tokens, expand_pytorch, rejection_greedy_sample_pytorch,
    rejection_random_sample_pytorch, sample_recovered_tokens_pytorch)

# Global constants
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = 0.0
MAX_SPEC_LEN = 8  # Used as MAX_NUM_TOKENS in expand_batch_to_tokens


class TestAscendRejectionSampler(TestBase):

    def test_rejection_greedy_sample_pytorch(self):
        """Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
        batch_size = 2
        max_spec_len = 3
        output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                      PLACEHOLDER_TOKEN_ID)

        cu_num_draft_tokens = torch.tensor([2, 4])
        draft_token_ids = torch.tensor([10, 11, 20, 21])
        target_argmax = torch.tensor([10, 99, 20, 22])
        bonus_token_ids = torch.tensor([[100], [200]])

        is_greedy = torch.tensor([True, True])

        rejection_greedy_sample_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
        )
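
        # Expected result, inferred from the inputs above: for request 0 the
        # draft tokens [10, 11] are checked against the target argmax [10, 99];
        # 10 matches and is accepted, the mismatch at position 1 is replaced by
        # the target token 99 and sampling stops, so no bonus token is added.
        # For request 1 ([20, 21] vs [20, 22]) the mismatch at position 1
        # likewise stops sampling, leaving position 2 at PLACEHOLDER_TOKEN_ID.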
        assert output_token_ids[0, 0].item() == 10
        assert output_token_ids[0, 1].item() == 99
        assert output_token_ids[1, 0].item() == 20
        assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID

    def test_rejection_random_sample_pytorch(self):
        """Test random rejection sampling: accept based on uniform probability"""
        batch_size = 2
        max_spec_len = 3
        output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                      PLACEHOLDER_TOKEN_ID)

        cu_num_draft_tokens = torch.tensor([2, 1])
        draft_token_ids = torch.tensor([1, 0, 2])
        draft_probs = torch.tensor([
            [0.0, 0.6, 0.0, 0.4],  # vocab_size=4
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.5, 0.0, 0.0],
        ])
        target_probs = torch.tensor([
            [0.0, 0.8, 0.0, 0.2],
            [0.2, 0.1, 0.3, 0.4],
            [0.9, 0.1, 0.0, 0.0],
        ])
        bonus_token_ids = torch.tensor([[100], [200]])
        recovered_token_ids = torch.tensor([1, 2, 3])
        uniform_probs = torch.tensor([0.7, 0.6, 0.5])
        is_greedy = torch.tensor([False, False])
        vocab_size = 4

        rejection_random_sample_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            bonus_token_ids,
            recovered_token_ids,
            uniform_probs,
            is_greedy,
            max_spec_len,
            vocab_size,
            IS_NGRAM=False,
        )
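
        # Expected result, inferred from the assertions and the usual
        # accept-if-uniform < target/draft rule: request 0 has draft tokens
        # 1 and 0, whose ratios (0.8/0.6 and 0.2/0.1) both exceed the uniform
        # draws (0.7 and 0.6), so both drafts are accepted and the bonus
        # token 100 lands at position 2.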
        assert output_token_ids[0, 0].item() == 1
        assert output_token_ids[0, 1].item() == 0
        assert output_token_ids[0, 2].item() == 100

    def test_expand_pytorch(self):
        """Test expand_pytorch functionality"""
        input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
        cu_num_tokens_ptr = torch.tensor([2, 5, 7])
        output_ptr = torch.empty(7, dtype=torch.int32)

        expand_pytorch(
            output_ptr,
            input_ptr,
            cu_num_tokens_ptr,
            replace_from=0,
            replace_to=0,
            MAX_NUM_TOKENS=MAX_SPEC_LEN,
        )
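
        # cu_num_tokens_ptr [2, 5, 7] gives per-request token counts of
        # 2, 3 and 2, so each input id is repeated that many times in order.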
        expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
        assert torch.equal(output_ptr, expected)

    def test_expand_batch_to_tokens(self):
        """Test expand_batch_to_tokens wrapper"""
        x = torch.tensor([10, 20, 30])
        cu_num_tokens = torch.tensor([2, 5, 7])
        num_tokens = 7
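
        # The wrapper should forward x and cu_num_tokens to the expand_pytorch
        # kernel; patching the kernel lets the test inspect the positional
        # arguments it receives.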
        with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
                   ) as mock_kernel:
            expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
            mock_kernel.assert_called_once()
            args = mock_kernel.call_args[0]
            assert (args[1] == x).all()
            assert (args[2] == cu_num_tokens).all()

        # Run actual function
        result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
        expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
        assert torch.equal(result, expected)

    def test_sample_recovered_tokens_pytorch_ngram(self):
        """Test recovered token sampling under n-gram mode"""
        output_token_ids = torch.empty(2, dtype=torch.int32)
        cu_num_draft_tokens = torch.tensor([1, 2])
        draft_token_ids = torch.tensor([1, 2])
        draft_probs = None
        target_probs = torch.tensor([
            [0.1, 0.2, 0.7],
            [0.3, 0.3, 0.4],
        ])
        q = torch.tensor([
            [0.1, 0.2, 0.7],
            [0.5, 0.4, 0.1],
        ])
        vocab_size = 3

        sample_recovered_tokens_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            q,
            vocab_size,
            IS_NGRAM=True,
        )
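
        # Expected result, inferred from the assertions: in n-gram mode the
        # draft token's target probability is zeroed before the recovered
        # token is chosen via argmax over target_probs / q, which yields
        # token 0 for the first position and token 1 for the second.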
        assert output_token_ids[0].item() == 0
        assert output_token_ids[1].item() == 1

    def test_sample_recovered_tokens_pytorch_autoregressive(self):
        """Test recovered token sampling for autoregressive models"""
        output_token_ids = torch.empty(2, dtype=torch.int32)
        cu_num_draft_tokens = torch.tensor([1, 1])
        draft_token_ids = torch.tensor([0, 1])
        draft_probs = torch.tensor([
            [0.6, 0.1, 0.3],
            [0.2, 0.7, 0.1],
        ])
        target_probs = torch.tensor([
            [0.8, 0.1, 0.1],
            [0.3, 0.6, 0.1],
        ])
        q = torch.tensor([
            [0.5, 0.3, 0.2],
            [0.1, 0.8, 0.1],
        ])
        vocab_size = 3

        sample_recovered_tokens_pytorch(
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            draft_probs,
            target_probs,
            q,
            vocab_size,
            IS_NGRAM=False,
        )
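        # Expected result, inferred from the assertion: with draft_probs
        # supplied, the recovered distribution is roughly max(target - draft, 0),
        # so request 0 concentrates on token 0 ([0.2, 0.0, 0.0] before
        # normalization) and the recovered token is 0.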
        assert output_token_ids[0].item() == 0

0 commit comments
