Skip to content

Commit 2ad7047

Browse files
yanguleiczhu15
authored and committed
fix preempted prompts (#928)
### Motivation The preempted prompts might failed to mitch the `num_computed_tokens < num_prompt_tokens` test and be treated as decoding then cause runtime error. ### Changes - add `_is_prompt()` to check if a request is prompt or not. - consider the `num_scheduled_tokens` to handle the preempted prompts. - add test for preemption handling to the CI. --------- Signed-off-by: Youlei Yang <youlei.yang@intel.com>
1 parent 37f2119 commit 2ad7047

File tree

3 files changed

+69
-18
lines changed

3 files changed

+69
-18
lines changed

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,13 @@ run_qwen3_vl_test() {
273273
echo "✅ Test with multimodal-support with qwen3-vl-32b passed."
274274
}
275275

276+
# Preemption test
277+
run_preemption_test() {
278+
echo "➡️ Testing preemption handling..."
279+
VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/preemption.py"
280+
echo "✅ Test with preemption handling passed."
281+
}
282+
276283
# Spec decode with ngram
277284
run_spec_decode_ngram_test() {
278285
echo "➡️ Testing Spec-decode with ngram..."

tests/full_tests/preemption.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from vllm import LLM, SamplingParams
5+
6+
# Sample prompts.
7+
prompts = [
8+
"Hello, my name is",
9+
"The president of the United States is",
10+
"The capital of France is",
11+
"The future of AI is",
12+
]
13+
# Create a sampling params object.
14+
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512, ignore_eos=True)
15+
16+
17+
def main():
18+
# Create an LLM.
19+
llm = LLM(
20+
model="meta-llama/Meta-Llama-3-8B-Instruct",
21+
block_size=128,
22+
max_model_len=1024,
23+
max_num_batched_tokens=1024,
24+
gpu_memory_utilization=0.9,
25+
num_gpu_blocks_override=8, # to trigger preemption
26+
disable_log_stats=False,
27+
)
28+
# Generate texts from the prompts.
29+
# The output is a list of RequestOutput objects
30+
# that contain the prompt, generated text, and other information.
31+
outputs = llm.generate(prompts, sampling_params)
32+
# Print the outputs.
33+
print("\nGenerated Outputs:\n" + "-" * 60)
34+
for output in outputs:
35+
prompt = output.prompt
36+
generated_text = output.outputs[0].text
37+
print(f"Prompt: {prompt!r}")
38+
print(f"Output: {generated_text!r}")
39+
print("-" * 60)
40+
41+
42+
if __name__ == "__main__":
43+
main()

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,20 @@ def _get_num_decodes(self) -> int:
15281528
num_decodes += 1
15291529
return num_decodes
15301530

1531+
def _is_prompt(self, req_idx: int, scheduler_output: "SchedulerOutput") -> bool:
1532+
req_id = self.input_batch.req_ids[req_idx]
1533+
num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_idx])
1534+
num_prompt_tokens = int(self.input_batch.num_prompt_tokens[req_idx])
1535+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id)
1536+
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
1537+
1538+
num_decode_tokens = 1 if spec_decode_tokens is None else len(spec_decode_tokens) + 1
1539+
is_prompt = num_computed_tokens < num_prompt_tokens # normal prompt
1540+
is_prompt = is_prompt or num_scheduled_tokens > num_decode_tokens # maybe preempted prompt
1541+
is_prompt = is_prompt and not self.is_decoder_only(req_id)
1542+
1543+
return is_prompt
1544+
15311545
def _get_prompts_and_decodes(
15321546
self,
15331547
scheduler_output: "SchedulerOutput",
@@ -1551,7 +1565,6 @@ def _get_prompts_and_decodes(
15511565
requests = scheduler_output.kv_connector_metadata.requests
15521566
else:
15531567
requests = None
1554-
15551568
# Traverse decodes first
15561569
decode_req_ids = []
15571570
num_computed_tokens_decode = []
@@ -1565,19 +1578,11 @@ def _get_prompts_and_decodes(
15651578
self.input_batch.req_type[req_id] = requests_type[req_id]
15661579
break
15671580

1568-
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
1569-
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
1570-
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1571-
if num_computed_tokens < num_prompt_tokens and \
1572-
not self.is_decoder_only(req_id):
1573-
# This is prompt
1581+
if self._is_prompt(i, scheduler_output):
15741582
break
15751583

1576-
# This is decode
1577-
# NOTE(chendi): To support spec decode,
1578-
# we don't assume num_scheduled_tokens == 1.
1579-
15801584
decode_req_ids.append(req_id)
1585+
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
15811586
num_computed_tokens_decode.append(int(num_computed_tokens + 1))
15821587

15831588
if self.profiler.enabled:
@@ -1590,16 +1595,12 @@ def _get_prompts_and_decodes(
15901595
req_id = self.input_batch.req_ids[i]
15911596
assert req_id is not None
15921597

1593-
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
1594-
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
1595-
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1596-
15971598
# Must be prompt
1598-
assert num_computed_tokens < num_prompt_tokens
1599-
# NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill
1599+
assert self._is_prompt(i, scheduler_output)
16001600

16011601
prompt_req_ids.append(req_id)
1602-
prompt_scheduled_tokens.append(num_scheduled_tokens)
1602+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1603+
prompt_scheduled_tokens.append(int(num_scheduled_tokens))
16031604

16041605
return PromptDecodeInfo(prompt_req_ids, decode_req_ids, prompt_scheduled_tokens)
16051606

0 commit comments

Comments
 (0)