Commit 002800f

Align vLLM's beam search implementation with HF generate (#857)
1 parent e15932b commit 002800f

24 files changed, +599 −263 lines changed

docs/source/models/adding_model.rst
Lines changed: 1 addition & 1 deletion

```diff
@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
 +    kv_caches: List[KVCache],
 +    input_metadata: InputMetadata,
 +    cache_events: Optional[List[torch.cuda.Event]],
-+) -> Dict[int, SequenceOutputs]:
++) -> SamplerOutput:
 
 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
```
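Put together, the doc's example signature now reads as sketched below. This only restates the hunk above: the `self`, `input_ids`, and `positions` parameters and the vLLM types (`KVCache`, `InputMetadata`, `SamplerOutput`) are assumed from the unchanged parts of the doc example, not introduced here.

```python
# Sketch of the full rewritten forward signature the doc now prescribes.
# The first three parameters and the imported vLLM types are assumed from
# the surrounding, unchanged doc lines; only the return type changes here.
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[KVCache],
    input_metadata: InputMetadata,
    cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
    ...
```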

tests/conftest.py
Lines changed: 56 additions & 10 deletions

```diff
@@ -67,8 +67,8 @@ def generate(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
-            )[0]
-            output_ids = output_ids[0].cpu().tolist()
+            )
+            output_ids = output_ids.cpu().tolist()
             outputs.append((output_ids, output_str))
         return outputs
 
@@ -77,8 +77,34 @@ def generate_greedy(
         prompts: List[str],
         max_tokens: int,
     ) -> List[Tuple[List[int], str]]:
-        return self.generate(prompts, do_sample=False,
-                             max_new_tokens=max_tokens)
+        outputs = self.generate(prompts,
+                                do_sample=False,
+                                max_new_tokens=max_tokens)
+        for i in range(len(outputs)):
+            output_ids, output_str = outputs[i]
+            outputs[i] = (output_ids[0], output_str[0])
+        return outputs
+
+    def generate_beam_search(
+        self,
+        prompts: List[str],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[int], str]]:
+        outputs = self.generate(prompts,
+                                do_sample=False,
+                                max_new_tokens=max_tokens,
+                                num_beams=beam_width,
+                                num_return_sequences=beam_width)
+        for i in range(len(outputs)):
+            output_ids, output_str = outputs[i]
+            for j in range(len(output_ids)):
+                output_ids[j] = [
+                    x for x in output_ids[j]
+                    if x != self.tokenizer.pad_token_id
+                ]
+            outputs[i] = (output_ids, output_str)
+        return outputs
 
 
 @pytest.fixture
@@ -107,15 +133,20 @@ def generate(
         prompts: List[str],
         sampling_params: SamplingParams,
     ) -> List[Tuple[List[int], str]]:
-        req_outputs = self.model.generate(
-            prompts, sampling_params=sampling_params)
+        req_outputs = self.model.generate(prompts,
+                                          sampling_params=sampling_params)
         outputs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            output_str = req_output.outputs[0].text
-            output_ids = req_output.outputs[0].token_ids
-            outputs.append((prompt_ids + output_ids, prompt_str + output_str))
+            req_sample_output_ids = []
+            req_sample_output_strs = []
+            for sample in req_output.outputs:
+                output_str = sample.text
+                output_ids = sample.token_ids
+                req_sample_output_ids.append(prompt_ids + output_ids)
+                req_sample_output_strs.append(prompt_str + output_str)
+            outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
     def generate_greedy(
@@ -124,7 +155,22 @@ def generate_greedy(
         max_tokens: int,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        return self.generate(prompts, greedy_params)
+        outputs = self.generate(prompts, greedy_params)
+        return [(output_ids[0], output_str[0]) for output_ids, output_str in
+                outputs]
+
+    def generate_beam_search(
+        self,
+        prompts: List[str],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[int], str]]:
+        beam_search_params = SamplingParams(n=beam_width,
+                                            use_beam_search=True,
+                                            temperature=0.0,
+                                            max_tokens=max_tokens)
+        outputs = self.generate(prompts, beam_search_params)
+        return outputs
 
 
 @pytest.fixture
```
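For orientation, the two new `generate_beam_search` helpers drive the same beam-search request through each API. The snippet below is a minimal, illustrative sketch of the equivalent direct calls and is not part of the commit; the model name and max-token value are taken from the test settings, while the prompt is a placeholder.

```python
# Hedged sketch of the two beam-search entry points the fixtures wrap.
# "facebook/opt-125m" and max_tokens=128 match the test settings; the
# prompt is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

prompt = "Hello, my name is"
beam_width = 4

# HF generate: beam search via num_beams / num_return_sequences.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
hf_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
hf_output_ids = hf_model.generate(input_ids,
                                  do_sample=False,
                                  max_new_tokens=128,
                                  num_beams=beam_width,
                                  num_return_sequences=beam_width)

# vLLM: beam search via SamplingParams, as the new fixture builds it.
llm = LLM(model="facebook/opt-125m")
beam_search_params = SamplingParams(n=beam_width,
                                    use_beam_search=True,
                                    temperature=0.0,
                                    max_tokens=128)
vllm_outputs = llm.generate([prompt], sampling_params=beam_search_params)
```

With this commit, the token sequences produced by the two paths are expected to match, which is exactly what the new test below asserts.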

tests/samplers/test_beam_search.py
Lines changed: 46 additions & 0 deletions

```diff
@@ -0,0 +1,46 @@
+"""Compare the outputs of HF and vLLM when using beam search.
+
+Run `pytest tests/samplers/test_beam_search.py --forked`.
+"""
+import pytest
+
+# FIXME(zhuohan): The test can not pass if we:
+#     1. Increase max_tokens to 256.
+#     2. Increase beam_width to 8.
+#     3. Use the model "huggyllama/llama-7b".
+MAX_TOKENS = [128]
+BEAM_WIDTHS = [4]
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
+def test_beam_search_single_input(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
+                                               max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
+                                                   max_tokens)
+    del vllm_model
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, _ = hf_outputs[i]
+        vllm_output_ids, _ = vllm_outputs[i]
+        assert len(hf_output_ids) == len(vllm_output_ids)
+        for j in range(len(hf_output_ids)):
+            assert hf_output_ids[j] == vllm_output_ids[j], (
+                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
+                f"vLLM: {vllm_output_ids}")
```

vllm/core/block_manager.py
Lines changed: 2 additions & 6 deletions

```diff
@@ -172,9 +172,7 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
 
@@ -203,9 +201,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
 
```
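Both hunks replace the earlier "iterate over all sequences and skip the finished ones" pattern with an explicit status filter: swap_in now walks only SWAPPED sequences and swap_out only RUNNING ones. The accessor itself lives in vllm/sequence.py, which is not shown in this excerpt; the standalone sketch below uses toy stand-in classes purely to illustrate the filtering idea.

```python
# Toy, self-contained sketch of status-filtered sequence lookup. The real
# SequenceGroup/Sequence/SequenceStatus live in vllm/sequence.py; these
# stand-ins only mirror the filtering the block manager now relies on.
import enum
from dataclasses import dataclass, field
from typing import List, Optional


class SequenceStatus(enum.Enum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()


@dataclass
class Sequence:
    seq_id: int
    status: SequenceStatus


@dataclass
class SequenceGroup:
    seqs: List[Sequence] = field(default_factory=list)

    def get_seqs(self,
                 status: Optional[SequenceStatus] = None) -> List[Sequence]:
        # No filter: return every sequence. With a filter: keep only the
        # sequences in the requested state (e.g. SWAPPED for swap_in).
        if status is None:
            return list(self.seqs)
        return [seq for seq in self.seqs if seq.status == status]


if __name__ == "__main__":
    group = SequenceGroup([Sequence(0, SequenceStatus.SWAPPED),
                           Sequence(1, SequenceStatus.FINISHED_ABORTED)])
    # Only seq 0 would be considered by swap_in under the new filtering.
    print([seq.seq_id for seq in group.get_seqs(SequenceStatus.SWAPPED)])
```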

vllm/core/scheduler.py
Lines changed: 45 additions & 69 deletions

```diff
@@ -7,8 +7,7 @@
 from vllm.core.policy import PolicyFactory
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+                           SequenceGroupMetadata, SequenceStatus)
 
 logger = init_logger(__name__)
 
@@ -76,6 +75,7 @@ def __init__(
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
         )
 
+        # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
@@ -96,10 +96,11 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
                 if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
-                    for seq in seq_group.seqs:
+                    for seq in seq_group.get_seqs():
                         if seq.is_finished():
                             continue
-                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
                     request_ids.remove(seq_group.request_id)
                     if not request_ids:
                         return
@@ -123,13 +124,20 @@ def _schedule(self) -> SchedulerOutputs:
         if not self.swapped:
             ignored_seq_groups: List[SequenceGroup] = []
             scheduled: List[SequenceGroup] = []
+            # The total number of sequences on the fly, including the
+            # requests in the generation phase.
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
             num_batched_tokens = 0
             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
             # are added to the back.
             while self.waiting:
                 seq_group = self.waiting[0]
 
+                assert seq_group.num_seqs() == 1, (
+                    "Waiting sequence group should have only one prompt "
+                    "sequence.")
                 num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if num_prompt_tokens > self.prompt_limit:
                     logger.warning(
@@ -152,11 +160,7 @@ def _schedule(self) -> SchedulerOutputs:
 
                 # The total number of sequences in the RUNNING state should not
                 # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
+                num_new_seqs = seq_group.get_max_num_running_seqs()
                 if (num_curr_seqs + num_new_seqs >
                         self.scheduler_config.max_num_seqs):
                     break
@@ -165,6 +169,7 @@ def _schedule(self) -> SchedulerOutputs:
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_batched_tokens += num_prompt_tokens
+                num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)
 
             if scheduled:
@@ -210,30 +215,32 @@ def _schedule(self) -> SchedulerOutputs:
 
         # Swap in the sequence groups in the SWAPPED state if possible.
         self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped and not blocks_to_swap_out:
-            seq_group = self.swapped[0]
-            # If the sequence group has been preempted in this step, stop.
-            if seq_group in preempted:
-                break
-            # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
-                break
-
-            # The total number of sequences in the RUNNING state should not
-            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            num_curr_seqs = sum(
-                seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                for seq_group in self.running)
-            if (num_curr_seqs + num_new_seqs >
-                    self.scheduler_config.max_num_seqs):
-                break
-
-            seq_group = self.swapped.pop(0)
-            self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slot(seq_group, blocks_to_copy)
-            self.running.append(seq_group)
+        if not preempted:
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
+
+            while self.swapped:
+                seq_group = self.swapped[0]
+                # If the sequence group cannot be swapped in, stop.
+                if not self.block_manager.can_swap_in(seq_group):
+                    break
 
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.swapped.pop(0)
+                self._swap_in(seq_group, blocks_to_swap_in)
+                self._append_slot(seq_group, blocks_to_copy)
+                num_curr_seqs += num_new_seqs
+                self.running.append(seq_group)
+
+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)
@@ -275,40 +282,10 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
 
-    def update(
-        self,
-        seq_outputs: Dict[int, SequenceOutputs],
-    ) -> List[SequenceGroup]:
-        scheduled: List[SequenceGroup] = []
-        for seq_group in self.running:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                if seq.seq_id in seq_outputs:
-                    scheduled.append(seq_group)
-                    break
-
-        # Update the scheduled sequences and free blocks.
-        for seq_group in scheduled:
-            # Process beam search results before processing the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                output = seq_outputs[seq.seq_id]
-                if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam
-                    # search). Free the current sequence.
-                    self.block_manager.free(seq)
-                    # Fork the parent sequence.
-                    parent_seq = seq_group.find(output.parent_seq_id)
-                    parent_seq.fork(seq)
-                    self.block_manager.fork(parent_seq, seq)
-
-            # Process the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Append a new token to the sequence.
-                output = seq_outputs[seq.seq_id]
-                seq.append_token_id(output.output_token, output.logprobs)
-        return scheduled
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        self.block_manager.fork(parent_seq, child_seq)
 
-    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
-        seq.status = finish_status
+    def free_seq(self, seq: Sequence) -> None:
         self.block_manager.free(seq)
 
     def free_finished_seq_groups(self) -> None:
@@ -345,17 +322,16 @@ def _preempt(
         # If preemption mode is not specified, we determine the mode as follows:
         # We use recomputation by default since it incurs lower overhead than
         # swapping. However, when the sequence group has multiple sequences
-        # (e.g., beam search), recomputation is not supported. In such a case,
-        # we use swapping instead.
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
         # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
         # As swapped sequences are prioritized over waiting sequences,
         # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
         # TODO(woosuk): Support recomputation for sequence groups with multiple
         # sequences. This may require a more sophisticated CUDA kernel.
         if preemption_mode is None:
-            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-            if len(seqs) == 1:
+            if seq_group.get_max_num_running_seqs() == 1:
                 preemption_mode = PreemptionMode.RECOMPUTE
             else:
                 preemption_mode = PreemptionMode.SWAP
```
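The scheduler hunks above budget batch capacity with `seq_group.get_max_num_running_seqs()` instead of counting only currently RUNNING sequences, so a beam-search request reserves room for all of its future beams while it is still a single waiting or swapped prompt sequence. That method is defined in vllm/sequence.py, which this commit also touches but which is not shown in this excerpt; the snippet below is only a hedged, standalone sketch of the idea using toy stand-in classes, not the actual implementation.

```python
# Hedged sketch (not the vllm/sequence.py source) of an upper bound on how
# many sequences of a request can be RUNNING at once: a beam-search or
# best_of-style request eventually occupies `best_of` slots even while it is
# still a single prompt sequence in the WAITING or SWAPPED queue.
from dataclasses import dataclass


@dataclass
class SamplingParamsSketch:
    best_of: int = 1
    use_beam_search: bool = False


@dataclass
class SequenceGroupSketch:
    sampling_params: SamplingParamsSketch
    num_unfinished_seqs: int = 1  # one prompt sequence before generation

    def get_max_num_running_seqs(self) -> int:
        # Beam search keeps `best_of` candidate beams alive for the rest of
        # the request, so reserve that many slots up front.
        if self.sampling_params.use_beam_search:
            return self.sampling_params.best_of
        # Otherwise the group grows to at most `best_of` parallel samples
        # and shrinks as individual sequences finish.
        return max(self.sampling_params.best_of, self.num_unfinished_seqs)


if __name__ == "__main__":
    group = SequenceGroupSketch(
        SamplingParamsSketch(best_of=4, use_beam_search=True))
    print(group.get_max_num_running_seqs())  # 4 slots reserved up front
```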
