
Commit 947b794

zhuohan123 and Yard1 authored
[Sampler] Vectorized sampling (simplified) (#1048)
Co-authored-by: Antoni Baum <[email protected]>
1 parent 8d926e9 commit 947b794
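For context, "vectorized sampling" means performing the sampling step for a whole batch of sequences with batched tensor operations rather than a Python loop over individual sequences. Below is a minimal, hypothetical sketch of the idea (sample_batch is an illustrative name, not code from this commit):

import torch

def sample_batch(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Pick one next token per row of a (batch, vocab) logits tensor."""
    if temperature == 0:
        # Greedy: argmax over the vocab dimension for every row at once.
        return torch.argmax(logits, dim=-1)
    # Random sampling: temperature-scale, softmax, then a single batched
    # multinomial draw instead of one draw per sequence.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

# Example: 4 sequences over a 32000-token vocabulary.
logits = torch.randn(4, 32000)
next_tokens = sample_batch(logits, temperature=0.8)  # shape: (4,)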

File tree

3 files changed: +481 -180 lines changed


tests/samplers/test_sampler.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
import pytest
import random
from typing import Tuple
from unittest.mock import patch

import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            with patch("vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
                return super().forward(*args, **kwargs)


def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output:
            assert nth_output.output_token in expected_tokens
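
Assuming a CUDA-capable machine (the _prepare_test helper allocates tensors with device="cuda"), the new tests can be run with pytest in the usual way, for example:

pytest tests/samplers/test_sampler.py

Note that each test is parametrized over RANDOM_SEEDS = list(range(128)), so a full run executes every test function 128 times with different seeds.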
