Skip to content

Commit b0f4a6e

Browse files
committed
add preemption handling to CI
Signed-off-by: Youlei Yang <youlei.yang@intel.com>
1 parent aeed597 commit b0f4a6e

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,13 @@ run_qwen2_5_vl_unified_attn_test() {
265265
echo "✅ Test multimodal-support + unified attention with qwen2.5-vl-7b passed."
266266
}
267267

268+
# Preemption test
269+
run_preemption_test() {
    # Runs the preemption example script with warmup skipped and
    # PT_HPU_LAZY_MODE=1, bracketed by start/pass banners.
    echo "➡️ Testing preemption handling..."
    local preemption_script="${VLLM_GAUDI_PREFIX}/tests/full_tests/preemption.py"
    VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${preemption_script}"
    echo "✅ Test with preemption handling passed."
}
274+
268275
# Spec decode with ngram
269276
run_spec_decode_ngram_test() {
270277
echo "➡️ Testing Spec-decode with ngram..."

tests/full_tests/preemption.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from vllm import LLM, SamplingParams
5+
6+
# Prompts submitted together in a single generate() call.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Sampling configuration shared by every prompt.
# NOTE(review): ignore_eos=True with max_tokens=512 presumably makes each
# request decode the full 512 tokens -- confirm against SamplingParams docs.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=512,
    ignore_eos=True,
)
15+
16+
17+
def main():
    """Generate under a tiny KV-cache budget and verify every request completes.

    The engine is built with ``num_gpu_blocks_override=8`` and
    ``block_size=128`` (1024 cache slots in total), while the module-level
    ``prompts`` and ``sampling_params`` ask for four requests of 512 generated
    tokens each.  The requests therefore cannot all stay resident at once,
    which is presumably what drives the scheduler into preemption.

    Raises:
        RuntimeError: If the number of returned outputs differs from the
            number of prompts, or any request produced a number of tokens
            other than ``sampling_params.max_tokens``.
    """
    # Create an LLM with a deliberately small number of KV-cache blocks.
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        block_size=128,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        gpu_memory_utilization=0.9,
        num_gpu_blocks_override=8,
        disable_log_stats=False,
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    problems = []
    if len(outputs) != len(prompts):
        problems.append(
            f"expected {len(prompts)} outputs, got {len(outputs)}")

    # Print the outputs and check that each request ran to completion.
    expected_tokens = sampling_params.max_tokens
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        completion = output.outputs[0]
        generated_text = completion.text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)
        # With ignore_eos=True a request only stops at max_tokens, so a
        # shorter (or longer) output means tokens were lost or duplicated.
        num_tokens = len(completion.token_ids)
        if num_tokens != expected_tokens:
            problems.append(
                f"prompt {prompt!r}: generated {num_tokens} tokens, "
                f"expected {expected_tokens}")

    # Fail loudly so the CI step exits non-zero instead of reporting success.
    if problems:
        raise RuntimeError(
            "Preemption test failed:\n  " + "\n  ".join(problems))


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)