Skip to content

Commit ca06974

Browse files
committed
Add Multilora test
Signed-off-by: Vivek <[email protected]>
1 parent fb18a7e commit ca06974

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from typing import Optional
2+
3+
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
4+
from vllm.lora.request import LoRARequest
5+
6+
7+
def create_test_prompts(
        lora_path: str
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the fixed request set used by the multi-LoRA test.

    The result holds six ``(prompt, sampling_params, lora_request)`` tuples:
    two base-model requests (``lora_request`` is ``None``) followed by four
    text-to-SQL requests that go through a LoRA adapter. Two adapter
    identities are used, ``"sql-lora"`` (id 1) and ``"sql-lora2"`` (id 2);
    both are created from the same ``lora_path``.

    Args:
        lora_path: Adapter location handed to every ``LoRARequest``.

    Returns:
        The requests, in the order they are meant to be submitted.
    """
    icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    nationality_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]] = []

    # Base-model requests: no adapter attached.
    prompts.append((
        "A robot may not injure a human being",
        SamplingParams(temperature=0.0, max_tokens=128),
        None,
    ))
    prompts.append((
        "To be or not to be,",
        SamplingParams(temperature=0.0, top_k=5, max_tokens=128),
        None,
    ))

    # LoRA requests as (prompt, temperature, adapter name, adapter id).
    # Every request gets its own SamplingParams and LoRARequest instance.
    sql_requests = (
        (icao_prompt, 0.0, "sql-lora", 1),
        (nationality_prompt, 0, "sql-lora", 1),
        (icao_prompt, 0.0, "sql-lora2", 2),
        (nationality_prompt, 0, "sql-lora", 1),
    )
    for prompt, temperature, adapter_name, adapter_id in sql_requests:
        sampling_params = SamplingParams(temperature=temperature,
                                         max_tokens=128,
                                         stop_token_ids=[32003])
        lora_request = LoRARequest(adapter_name, adapter_id, lora_path)
        prompts.append((prompt, sampling_params, lora_request))

    return prompts
63+
64+
65+
def process_requests(
    engine: LLMEngine,
    test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
) -> dict[str, str]:
    """Submit prompts to the engine one per step and collect finished texts.

    On every loop iteration at most one pending prompt is added to the
    engine (with a request id that is the stringified submission index,
    starting at ``"0"``), then ``engine.step()`` is called once. The loop
    ends when no prompts are pending and the engine reports no unfinished
    requests.

    Args:
        engine: Object providing ``add_request``, ``step`` and
            ``has_unfinished_requests``.
        test_prompts: ``(prompt, sampling_params, lora_request)`` tuples in
            submission order. The list is not modified.

    Returns:
        Mapping from request id to the text of the first output of each
        request that was reported as finished.
    """
    # Function-scope import so the block does not depend on a new
    # module-level import.
    from collections import deque

    # Work on a private queue: the caller's list stays intact and taking the
    # next prompt is O(1) instead of the O(n) of list.pop(0).
    pending = deque(test_prompts)
    result: dict[str, str] = {}
    request_id = 0

    while pending or engine.has_unfinished_requests():
        if pending:
            prompt, sampling_params, lora_request = pending.popleft()
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                result[request_output.request_id] = (
                    request_output.outputs[0].text)
    return result
88+
89+
90+
# Reference completions, one per request, listed in the order the prompts are
# returned by create_test_prompts (request ids "0" through "5").
expected_output = [
    # Request 0, base model: "A robot may not injure a human being"
    " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194",  # noqa: E501
    # Request 1, base model: "To be or not to be,"
    " that is the question.\nThe question is not whether you will be a leader, but whether you will be a good leader.\nThe question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The",  # noqa: E501
    # Request 2, "sql-lora": ICAO question
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
    # Request 3, "sql-lora": nationality question
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",  # noqa: E501
    # Request 4, "sql-lora2": ICAO question
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
    # Request 5, "sql-lora": nationality question
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",  # noqa: E501
]
98+
99+
100+
def _test_llama_multilora(sql_lora_files, tp_size):
    """Run the multi-LoRA prompts through an engine and check the texts.

    Creates an ``LLMEngine`` with LoRA enabled, submits the requests from
    ``create_test_prompts`` through ``process_requests``, and asserts that
    the generated texts, ordered by request id, equal ``expected_output``.

    Args:
        sql_lora_files: Forwarded to ``create_test_prompts`` as the LoRA
            path (presumably supplied by a pytest fixture -- confirm in
            conftest).
        tp_size: Forwarded to ``EngineArgs`` as ``tensor_parallel_size``.

    Raises:
        AssertionError: If the generated texts differ from
            ``expected_output``.
    """
    engine_args = EngineArgs(
        # NOTE(review): absolute, machine-specific checkpoint path; the test
        # can only run where this mount exists.
        model="/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf",
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=256,
        dtype='bfloat16',
        tensor_parallel_size=tp_size)
    engine = LLMEngine.from_engine_args(engine_args)
    test_prompts = create_test_prompts(sql_lora_files)
    results = process_requests(engine, test_prompts)
    # Request ids are stringified integers, so order them numerically; a
    # plain string sort would put "10" before "2" once there are more than
    # ten prompts and misalign the comparison below.
    generated_texts = [results[key] for key in sorted(results, key=int)]
    assert generated_texts == expected_output
115+
116+
117+
def test_llama_multilora_1x(sql_lora_files):
    """Run the multi-LoRA check with a tensor-parallel size of 1."""
    _test_llama_multilora(sql_lora_files, tp_size=1)
119+
120+
121+
#def test_llama_multilora_2x(sql_lora_files):
122+
# _test_llama_multilora(sql_lora_files, 2)
123+
124+
#def test_llama_multilora_4x(sql_lora_files):
125+
# _test_llama_multilora(sql_lora_files, 4)

0 commit comments

Comments
 (0)