
Commit 5da9022

heheda12345 authored and zhewenl committed
[Frontend] Optimize beam search performance by limiting concurrency (vllm-project#23599)
Signed-off-by: Chen Zhang <[email protected]>
1 parent fb4345f commit 5da9022
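
The user-visible change is a new concurrency_limit keyword argument on LLM.beam_search. The sketch below is a minimal usage example, not part of the commit; the model name and prompts are placeholder assumptions, and the call shape mirrors the updated tests/conftest.py helper.

    from vllm import LLM
    from vllm.inputs import TextPrompt
    from vllm.sampling_params import BeamSearchParams

    # Placeholder model and prompts, chosen only for illustration.
    llm = LLM(model="facebook/opt-125m")
    prompts = [TextPrompt(prompt=p) for p in
               ["Hello, my name is", "The capital of France is"]]
    params = BeamSearchParams(beam_width=4, max_tokens=32)

    # New in this commit: beam-search at most 2 prompts at a time.
    # With concurrency_limit=None (the default) all prompts run at once,
    # matching the old behavior.
    outputs = llm.beam_search(prompts, params, concurrency_limit=2)
    for output in outputs:
        print(output.sequences[0].tokens)

Capping how many prompts are processed together bounds the per-step generation batch to roughly concurrency_limit * beam_width in-flight sequences.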

File tree: 4 files changed (+143, -71 lines)


benchmarks/benchmark_throughput.py

Lines changed: 0 additions & 1 deletion
@@ -96,7 +96,6 @@ def run_vllm(
         end = time.perf_counter()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
-        prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0].expected_output_len
         for request in requests:

tests/conftest.py

Lines changed: 5 additions & 3 deletions
@@ -1022,15 +1022,17 @@ def generate_beam_search(
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        concurrency_limit: Optional[int] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

-        outputs = self.llm.beam_search(
-            inputs,
-            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
+        outputs = self.llm.beam_search(inputs,
+                                       BeamSearchParams(beam_width=beam_width,
+                                                        max_tokens=max_tokens),
+                                       concurrency_limit=concurrency_limit)
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]

tests/samplers/test_beam_search.py

Lines changed: 53 additions & 0 deletions
@@ -67,6 +67,59 @@ def test_beam_search_single_input(
             f"vLLM: {vllm_output_ids}")


+@pytest.mark.skip_v1  # FIXME: This fails on V1 right now.
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
+def test_beam_search_with_concurrency_limit(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    # example_prompts[1]&[3]&[7] fails due to unknown reason even without
+    # concurency limit. skip them for now.
+    example_prompts = (example_prompts[:8])
+    concurrency_limit = 2
+    assert len(example_prompts) > concurrency_limit
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        outputs_with_limit = vllm_model.generate_beam_search(
+            example_prompts,
+            beam_width,
+            max_tokens,
+            concurrency_limit=concurrency_limit)
+        outputs_without_limit = []
+
+        for i in range(0, len(example_prompts), concurrency_limit):
+            outputs_without_limit.extend(
+                vllm_model.generate_beam_search(
+                    example_prompts[i:i + concurrency_limit], beam_width,
+                    max_tokens))
+
+    correct = True
+    for i in range(len(example_prompts)):
+        output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i]
+        output_ids_without_limit, output_texts_without_limit = (
+            outputs_without_limit[i])
+        for j, (text_with_limit, text_without_limit) in enumerate(
+                zip(output_texts_with_limit, output_texts_without_limit)):
+            print(f">>>{j}-th with limit output:")
+            print(text_with_limit)
+            print(f">>>{j}-th without limit output:")
+            print(text_without_limit)
+        assert len(output_ids_with_limit) == len(output_ids_without_limit)
+        for j in range(len(output_ids_with_limit)):
+            if output_ids_with_limit[j] != output_ids_without_limit[j]:
+                print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
+                      f"-limit: {output_ids_without_limit}")
+                correct = False
+    assert correct
+
+
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
 @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)

vllm/entrypoints/llm.py

Lines changed: 85 additions & 67 deletions
@@ -523,6 +523,7 @@ def beam_search(
         params: BeamSearchParams,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         use_tqdm: bool = False,
+        concurrency_limit: Optional[int] = None,
     ) -> list[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -533,6 +534,8 @@ def beam_search(
             params: The beam search parameters.
             lora_request: LoRA request to use for generation, if any.
             use_tqdm: Whether to use tqdm to display the progress bar.
+            concurrency_limit: The maximum number of concurrent requests.
+                If None, the number of concurrent requests is unlimited.
         """
         # TODO: how does beam search work together with length penalty,
         # frequency, penalty, and stopping criteria, etc.?
@@ -551,6 +554,15 @@
             length_penalty,
         )

+        if use_tqdm and concurrency_limit is not None:
+            logger.warning(
+                "Progress bar is not supported when using concurrency_limit. "
+                "Disabling progress bar.")
+            use_tqdm = False
+
+        if concurrency_limit is None:
+            concurrency_limit = len(prompts)
+
         def create_tokens_prompt_from_beam(
                 beam: BeamSearchSequence) -> TokensPrompt:
             token_prompt_kwargs: TokensPrompt = {
@@ -595,73 +607,79 @@ def create_tokens_prompt_from_beam(
                     **mm_kwargs,
                 ), )

-        token_iter = range(max_tokens)
-        if use_tqdm:
-            token_iter = tqdm(token_iter,
-                              desc="Beam search",
-                              unit="token",
-                              unit_scale=False)
-            logger.warning(
-                "The progress bar shows the upper bound on token steps and "
-                "may finish early due to stopping conditions. It does not "
-                "reflect instance-level progress.")
-
-        for _ in token_iter:
-            all_beams: list[BeamSearchSequence] = list(
-                sum((instance.beams for instance in instances), []))
-            pos = [0] + list(
-                itertools.accumulate(
-                    len(instance.beams) for instance in instances))
-            instance_start_and_end: list[tuple[int, int]] = list(
-                zip(pos[:-1], pos[1:]))
-
-            if len(all_beams) == 0:
-                break
-
-            # create the corresponding batch entries for prompt & optional lora
-            prompts_batch, lora_req_batch = zip(
-                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
-                  for beam in all_beams])
-
-            # only runs for one step
-            # we don't need to use tqdm here
-            output = self.generate(prompts_batch,
-                                   sampling_params=beam_search_params,
-                                   use_tqdm=False,
-                                   lora_request=lora_req_batch)
-
-            for (start, end), instance in zip(instance_start_and_end,
-                                              instances):
-                instance_new_beams = []
-                for i in range(start, end):
-                    current_beam = all_beams[i]
-                    result = output[i]
-
-                    if result.outputs[0].logprobs is not None:
-                        # if `result.outputs[0].logprobs` is None, it means
-                        # the sequence is completed because of the max-model-len
-                        # or abortion. we don't need to add it to the new beams.
-                        logprobs = result.outputs[0].logprobs[0]
-                        for token_id, logprob_obj in logprobs.items():
-                            new_beam = BeamSearchSequence(
-                                tokens=current_beam.tokens + [token_id],
-                                logprobs=current_beam.logprobs + [logprobs],
-                                lora_request=current_beam.lora_request,
-                                cum_logprob=current_beam.cum_logprob +
-                                logprob_obj.logprob,
-                                multi_modal_data=current_beam.multi_modal_data,
-                                mm_processor_kwargs=current_beam.
-                                mm_processor_kwargs)
-
-                            if token_id == tokenizer.eos_token_id and \
-                                not ignore_eos:
-                                instance.completed.append(new_beam)
-                            else:
-                                instance_new_beams.append(new_beam)
-                sorted_beams = sorted(instance_new_beams,
-                                      key=sort_beams_key,
-                                      reverse=True)
-                instance.beams = sorted_beams[:beam_width]
+        for prompt_start in range(0, len(prompts), concurrency_limit):
+            instances_batch = instances[prompt_start:prompt_start +
+                                        concurrency_limit]
+
+            token_iter = range(max_tokens)
+            if use_tqdm:
+                token_iter = tqdm(token_iter,
+                                  desc="Beam search",
+                                  unit="token",
+                                  unit_scale=False)
+                logger.warning(
+                    "The progress bar shows the upper bound on token steps and "
+                    "may finish early due to stopping conditions. It does not "
+                    "reflect instance-level progress.")
+            for _ in token_iter:
+                all_beams: list[BeamSearchSequence] = list(
+                    sum((instance.beams for instance in instances_batch), []))
+                pos = [0] + list(
+                    itertools.accumulate(
+                        len(instance.beams) for instance in instances_batch))
+                instance_start_and_end: list[tuple[int, int]] = list(
+                    zip(pos[:-1], pos[1:]))
+
+                if len(all_beams) == 0:
+                    break
+
+                # create corresponding batch entries for prompt & optional lora
+                prompts_batch, lora_req_batch = zip(
+                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
+                      for beam in all_beams])
+
+                # only runs for one step
+                # we don't need to use tqdm here
+                output = self.generate(prompts_batch,
+                                       sampling_params=beam_search_params,
+                                       use_tqdm=False,
+                                       lora_request=lora_req_batch)
+
+                for (start, end), instance in zip(instance_start_and_end,
+                                                  instances_batch):
+                    instance_new_beams = []
+                    for i in range(start, end):
+                        current_beam = all_beams[i]
+                        result = output[i]
+
+                        if result.outputs[0].logprobs is not None:
+                            # if `result.outputs[0].logprobs` is None, it means
+                            # the sequence is completed because of the
+                            # max-model-len or abortion. we don't need to add
+                            # it to the new beams.
+                            logprobs = result.outputs[0].logprobs[0]
+                            for token_id, logprob_obj in logprobs.items():
+                                new_beam = BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    lora_request=current_beam.lora_request,
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs)
+
+                                if token_id == tokenizer.eos_token_id and \
+                                    not ignore_eos:
+                                    instance.completed.append(new_beam)
+                                else:
+                                    instance_new_beams.append(new_beam)
+                    sorted_beams = sorted(instance_new_beams,
+                                          key=sort_beams_key,
+                                          reverse=True)
+                    instance.beams = sorted_beams[:beam_width]

         outputs = []
         for instance in instances:
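
For readers skimming the diff, the control flow introduced above can be isolated as the small sketch below. This is an illustration only, not vLLM code; the chunked helper and the sample prompts are hypothetical.

    # Hypothetical, self-contained sketch of the chunking pattern used above.
    def chunked(items, limit):
        # Yield consecutive slices of at most `limit` items.
        for start in range(0, len(items), limit):
            yield items[start:start + limit]

    prompts = ["p0", "p1", "p2", "p3", "p4"]
    concurrency_limit = 2
    beam_width = 4

    for batch in chunked(prompts, concurrency_limit):
        # Each batch is beam-searched to completion before the next starts,
        # so at most len(batch) * beam_width sequences are expanded per step.
        print("beam searching", batch, "->", len(batch) * beam_width, "beams per step")

Previously the token-step loop expanded beams for every prompt at once; the rewritten loop runs the same per-step expansion, but only over instances_batch, a slice of at most concurrency_limit instances.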
