Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/full_tests/ci_gsm8k_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ run_qwen3_vl_test() {
echo "✅ Test with multimodal-support with qwen3-vl-32b passed."
}

# Preemption test: runs the standalone preemption script with warmup skipped
# and HPU lazy mode enabled.
run_preemption_test() {
    local test_script="${VLLM_GAUDI_PREFIX}/tests/full_tests/preemption.py"
    echo "➡️ Testing preemption handling..."
    VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${test_script}"
    echo "✅ Test with preemption handling passed."
}

# Spec decode with ngram
run_spec_decode_ngram_test() {
echo "➡️ Testing Spec-decode with ngram..."
Expand Down
43 changes: 43 additions & 0 deletions tests/full_tests/preemption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams

# Sample prompts.
prompts: list[str] = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
# ignore_eos=True with a large max_tokens forces every request to generate
# the full 512 tokens, keeping the KV cache under pressure long enough for
# the scheduler to preempt running requests (the point of this test).
sampling_params: SamplingParams = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512, ignore_eos=True)


def main():
    """Run a small generation job configured to force scheduler preemption.

    The engine is built with a deliberately tiny KV-cache budget
    (``num_gpu_blocks_override=8``) so that running requests are preempted
    and later resumed; the generated outputs are printed for inspection.
    """
    # Create an LLM with a starved block pool to trigger preemption.
    engine = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        block_size=128,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        gpu_memory_utilization=0.9,
        num_gpu_blocks_override=8,  # to trigger preemption
        disable_log_stats=False,
    )
    # generate() returns one RequestOutput per prompt, each carrying the
    # prompt text and its generated completions.
    results = engine.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for result in results:
        prompt = result.prompt
        generated_text = result.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
37 changes: 19 additions & 18 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,20 @@ def _get_num_decodes(self) -> int:
num_decodes += 1
return num_decodes

def _is_prompt(self, req_idx: int, scheduler_output: "SchedulerOutput") -> bool:
    """Return True if the request at ``req_idx`` should run as a prompt.

    A request counts as a prompt when it still has uncomputed prompt tokens
    (normal prefill), or when more tokens were scheduled for it than a
    single decode step would consume (a preempted request being
    re-prefilled). Decoder-only requests are never treated as prompts.
    """
    req_id = self.input_batch.req_ids[req_idx]
    num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_idx])
    num_prompt_tokens = int(self.input_batch.num_prompt_tokens[req_idx])
    # Default to 0: dict.get without a default returns None for an
    # unscheduled request, and ``None > int`` raises TypeError below
    # whenever the first (normal-prompt) condition is False.
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
    spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens.get(req_id)

    # A decode step consumes one token plus any scheduled speculative tokens.
    num_decode_tokens = 1 if spec_decode_tokens is None else len(spec_decode_tokens) + 1
    is_prompt = num_computed_tokens < num_prompt_tokens  # normal prompt
    is_prompt = is_prompt or num_scheduled_tokens > num_decode_tokens  # maybe preempted prompt
    is_prompt = is_prompt and not self.is_decoder_only(req_id)

    return is_prompt

def _get_prompts_and_decodes(
self,
scheduler_output: "SchedulerOutput",
Expand All @@ -1500,7 +1514,6 @@ def _get_prompts_and_decodes(
requests = scheduler_output.kv_connector_metadata.requests
else:
requests = None

# Traverse decodes first
decode_req_ids = []
num_computed_tokens_decode = []
Expand All @@ -1514,19 +1527,11 @@ def _get_prompts_and_decodes(
self.input_batch.req_type[req_id] = requests_type[req_id]
break

num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
if num_computed_tokens < num_prompt_tokens and \
not self.is_decoder_only(req_id):
# This is prompt
if self._is_prompt(i, scheduler_output):
break

# This is decode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is decode

why remove this comment?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a comment for the older `assert num_scheduled_tokens == 1` statement, which is no longer relevant to the current code.

# NOTE(chendi): To support spec decode,
# we don't assume num_scheduled_tokens == 1.

decode_req_ids.append(req_id)
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_computed_tokens_decode.append(int(num_computed_tokens + 1))

if self.profiler.enabled:
Expand All @@ -1539,16 +1544,12 @@ def _get_prompts_and_decodes(
req_id = self.input_batch.req_ids[i]
assert req_id is not None

num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]

# Must be prompt
assert num_computed_tokens < num_prompt_tokens
# NOTE(kzawora): In preempted sequences, num_output_tokens can be > 0, and still be a valid prefill
assert self._is_prompt(i, scheduler_output)

prompt_req_ids.append(req_id)
prompt_scheduled_tokens.append(num_scheduled_tokens)
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
prompt_scheduled_tokens.append(int(num_scheduled_tokens))

return PromptDecodeInfo(prompt_req_ids, decode_req_ids, prompt_scheduled_tokens)

Expand Down