Skip to content

Commit f1bf91d

Browse files
committed
fix preempted prompts (vllm-project#928)
The preempted prompts might fail to match the `num_computed_tokens < num_prompt_tokens` test and be treated as decoding, which then causes a runtime error. - add `_is_prompt()` to check whether a request is a prompt or not. - consider `num_scheduled_tokens` to handle the preempted prompts. - add a test for preemption handling to the CI. --------- Signed-off-by: Youlei Yang <youlei.yang@intel.com>
1 parent 6728857 commit f1bf91d

File tree

3 files changed

+69
-18
lines changed

3 files changed

+69
-18
lines changed

tests/full_tests/ci_e2e_discoverable_tests.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,13 @@ run_mistral3_test() {
351351
echo "✅ Test with multimodal-support with Mistral-Small-3.1-24B passed."
352352
}
353353

# Preemption test: exercises scheduler preemption handling end to end.
run_preemption_test() {
    echo "➡️ Testing preemption handling..."
    local test_script="${VLLM_GAUDI_PREFIX}/tests/full_tests/preemption.py"
    VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "$test_script"
    echo "✅ Test with preemption handling passed."
}
354361
# Spec decode with ngram
355362
run_spec_decode_ngram_test() {
356363
echo "➡️ Testing Spec-decode with ngram..."

tests/full_tests/preemption.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams

# Prompts fed to the engine; several concurrent requests are needed so the
# scheduler has something to preempt.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Long, non-terminating generations (ignore_eos) keep requests resident so
# preemption can actually occur.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=512,
    ignore_eos=True,
)
16+
17+
def main():
    """Generate from all prompts with a tiny KV-cache block budget.

    The deliberately small ``num_gpu_blocks_override`` starves the KV cache
    so the scheduler is forced to preempt running requests, exercising the
    preemption path.
    """
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        block_size=128,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        gpu_memory_utilization=0.9,
        num_gpu_blocks_override=8,  # to trigger preemption
        disable_log_stats=False,
    )

    # One RequestOutput per prompt, each carrying the prompt text and the
    # generated completions.
    outputs = llm.generate(prompts, sampling_params)

    separator = "-" * 60
    print("\nGenerated Outputs:\n" + separator)
    for request_output in outputs:
        print(f"Prompt: {request_output.prompt!r}")
        print(f"Output: {request_output.outputs[0].text!r}")
        print(separator)


if __name__ == "__main__":
    main()

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,20 @@ def _get_num_decodes(self) -> int:
16351635
num_decodes += 1
16361636
return num_decodes
16371637

def _is_prompt(self, req_idx: int, scheduler_output: "SchedulerOutput") -> bool:
    """Return True if the request at ``req_idx`` should run as a prompt (prefill).

    A request counts as a prompt when either:
      * it has not yet computed all of its prompt tokens (normal prefill), or
      * it was scheduled for more tokens than a single decode step would
        consume — the signature of a preempted request being re-prefilled
        after its KV blocks were evicted.

    Requests flagged decoder-only are never treated as prompts.
    """
    req_id = self.input_batch.req_ids[req_idx]
    num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_idx])
    num_prompt_tokens = int(self.input_batch.num_prompt_tokens[req_idx])
    # Default to 0 tokens: a request absent from the schedule must not be
    # classified as a preempted prompt, and `.get(req_id)` alone would yield
    # None and make the `>` comparison below raise TypeError.
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
    spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens.get(req_id)

    # A decode step consumes one token plus any scheduled speculative tokens.
    num_decode_tokens = 1 if spec_decode_tokens is None else len(spec_decode_tokens) + 1
    is_prompt = num_computed_tokens < num_prompt_tokens  # normal prompt
    is_prompt = is_prompt or num_scheduled_tokens > num_decode_tokens  # maybe preempted prompt
    is_prompt = is_prompt and not self.is_decoder_only(req_id)

    return is_prompt
16381652
def _get_prompts_and_decodes(
16391653
self,
16401654
scheduler_output: "SchedulerOutput",
@@ -1679,7 +1693,6 @@ def _get_prompts_and_decodes(
16791693
requests = metadata.reqs_to_store | metadata.reqs_to_load
16801694
else:
16811695
requests = scheduler_output.kv_connector_metadata.requests
1682-
16831696
# Traverse decodes first
16841697
decode_req_ids = []
16851698
num_computed_tokens_decode = []
@@ -1693,19 +1706,11 @@ def _get_prompts_and_decodes(
16931706
self.input_batch.req_type[req_id] = requests_type[req_id]
16941707
break
16951708

1696-
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
1697-
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
1698-
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1699-
if num_computed_tokens < num_prompt_tokens and \
1700-
not self.is_decoder_only(req_id):
1701-
# This is prompt
1709+
if self._is_prompt(i, scheduler_output):
17021710
break
17031711

1704-
# This is decode
1705-
# NOTE(chendi): To support spec decode,
1706-
# we don't assume num_scheduled_tokens == 1.
1707-
17081712
decode_req_ids.append(req_id)
1713+
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
17091714
num_computed_tokens_decode.append(int(num_computed_tokens + 1))
17101715

17111716
if self.profiler.enabled:
@@ -1718,16 +1723,12 @@ def _get_prompts_and_decodes(
17181723
req_id = self.input_batch.req_ids[i]
17191724
assert req_id is not None
17201725

1721-
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
1722-
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
1723-
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1724-
17251726
# Must be prompt
1726-
assert num_computed_tokens < num_prompt_tokens
1727-
# NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill
1727+
assert self._is_prompt(i, scheduler_output)
17281728

17291729
prompt_req_ids.append(req_id)
1730-
prompt_scheduled_tokens.append(num_scheduled_tokens)
1730+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1731+
prompt_scheduled_tokens.append(int(num_scheduled_tokens))
17311732

17321733
return PromptDecodeInfo(prompt_req_ids, decode_req_ids, prompt_scheduled_tokens)
17331734

0 commit comments

Comments
 (0)