Skip to content

Commit 5fd019e

Browse files
committed
Initial commit to add LoRA support
Remove dependency on LoRA worker class. First working version with simple example. Fixed BS>1 case. Fix in platform.py to avoid error due to missing vllm_config. Fix No LoRA case. Fix warmup with LoRA. Minor Cleanup. Disable HPU Graphs. Clean-up. Minor fixes. Signed-off-by: Vivek <[email protected]>. Add LoRA unit-test. Signed-off-by: Vivek <[email protected]>. Move LoRA configuration code to separate function. Signed-off-by: Vivek <[email protected]>. Add Multilora test. Signed-off-by: Vivek <[email protected]>. Fix mypy error. Signed-off-by: Vivek <[email protected]>. Update hpu_lora to use patching. Signed-off-by: Vivek <[email protected]>. Fix for model load error in CI. Signed-off-by: Vivek <[email protected]>.
1 parent f331b00 commit 5fd019e

File tree

6 files changed

+670
-23
lines changed

6 files changed

+670
-23
lines changed

tests/unit_tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
initialize_model_parallel)
44
import pytest
55
import tempfile
6+
from huggingface_hub import snapshot_download
67

78

89
@pytest.fixture
@@ -18,3 +19,14 @@ def dist_init():
1819
initialize_model_parallel(1, 1)
1920
yield
2021
cleanup_dist_env_and_memory()
22+
23+
24+
@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Provide the Hugging Face repo id of the SQL LoRA test adapter.

    A repo id (not a local path) is handed out so that tests can exercise
    downloading the LoRA adapter at runtime.
    """
    repo_id = "yard1/llama-2-7b-sql-lora-test"
    return repo_id
28+
29+
30+
@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    """Fetch the SQL LoRA adapter repo once per session.

    Returns whatever ``snapshot_download`` returns for the repo id supplied
    by the ``sql_lora_huggingface_id`` fixture.
    """
    downloaded_snapshot = snapshot_download(repo_id=sql_lora_huggingface_id)
    return downloaded_snapshot
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import Optional
2+
3+
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
4+
from vllm.lora.request import LoRARequest
5+
import os
6+
7+
# The model is loaded through a symlink to avoid the long-path error
# thrown by the HF Hub validation check. Downloading the model directly
# from the Hub would also work, but requires adding an HF token to the
# repo secrets.
src = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
dst = "test_model"
# Drop a link left over from an earlier run before re-creating it.
if os.path.islink(dst):
    os.unlink(dst)
os.symlink(src, dst)
MODEL_PATH = dst
17+
18+
19+
def create_test_prompts(
    lora_path: str
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the (prompt, sampling params, LoRA request) triples for the test.

    The first two entries target the base model (their LoRA request is
    ``None``). The remaining four are SQL prompts routed to LoRA adapters:
    three use adapter "sql-lora" (id 1) and one uses "sql-lora2" (id 2).
    Both adapter identities are loaded from the same ``lora_path``.
    """
    icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    nationality_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def sql_sampling(temperature):
        # Every request gets its own SamplingParams instance.
        return SamplingParams(temperature=temperature,
                              max_tokens=128,
                              stop_token_ids=[32003])

    base_model_cases = [
        (
            "A robot may not injure a human being",
            SamplingParams(temperature=0.0, max_tokens=128),
            None,
        ),
        (
            "To be or not to be,",
            SamplingParams(temperature=0.0, top_k=5, max_tokens=128),
            None,
        ),
    ]
    lora_cases = [
        (icao_prompt, sql_sampling(0.0),
         LoRARequest("sql-lora", 1, lora_path)),
        (nationality_prompt, sql_sampling(0),
         LoRARequest("sql-lora", 1, lora_path)),
        (icao_prompt, sql_sampling(0.0),
         LoRARequest("sql-lora2", 2, lora_path)),
        (nationality_prompt, sql_sampling(0),
         LoRARequest("sql-lora", 1, lora_path)),
    ]
    return base_model_cases + lora_cases
75+
76+
77+
def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Submit prompts to the engine one per step and collect finished texts.

    One pending prompt is added before each ``engine.step()`` call. Stepping
    continues until every prompt has been submitted and the engine reports
    no unfinished requests.

    Args:
        engine: Object exposing ``add_request``, ``step`` and
            ``has_unfinished_requests``.
        test_prompts: ``(prompt, sampling_params, lora_request)`` triples.
            The list is only read; it is not modified.

    Returns:
        Dict mapping each finished request id to the text of its first
        output. Ids are "0", "1", ... in submission order.
    """
    results: dict[str, str] = {}

    # Walk the prompts with an iterator instead of pop(0): the caller's list
    # stays intact and the O(n) front-pop is avoided.
    pending = iter(test_prompts)
    upcoming = next(pending, None)
    request_id = 0

    while upcoming is not None or engine.has_unfinished_requests():
        if upcoming is not None:
            prompt, sampling_params, lora_request = upcoming
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1
            upcoming = next(pending, None)

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results[
                    request_output.request_id] = request_output.outputs[0].text
    return results
100+
101+
102+
# Reference completions that the generated texts are compared against, in
# request order: two base-model continuations followed by four LoRA SQL
# answers (the two SQL prompts are each issued twice).
_ICAO_QUERY = " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
_NATIONALITY_QUERY = " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' "  # noqa: E501

expected_output = [
    " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194",  # noqa: E501
    " that is the question.\nThe question is not whether you will be a leader, but whether you will be a good leader.\nThe question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The",  # noqa: E501
    _ICAO_QUERY,
    _NATIONALITY_QUERY,
    _ICAO_QUERY,
    _NATIONALITY_QUERY,
]
110+
111+
112+
def _test_llama_multilora(sql_lora_files, tp_size):
    """Run the multi-LoRA prompts through an engine and check the outputs.

    Args:
        sql_lora_files: LoRA adapter location passed to
            ``create_test_prompts``.
        tp_size: Value forwarded as ``tensor_parallel_size``.

    Raises:
        AssertionError: If the generated texts, taken in request order,
            differ from ``expected_output``.
    """
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_lora=True,
                             max_loras=2,
                             max_lora_rank=8,
                             max_num_seqs=256,
                             dtype='bfloat16',
                             tensor_parallel_size=tp_size)
    engine = LLMEngine.from_engine_args(engine_args)
    test_prompts = create_test_prompts(sql_lora_files)
    results = process_requests(engine, test_prompts)
    # Request ids are decimal strings ("0", "1", ...). Sort them numerically:
    # a plain string sort would place "10" before "2" once there are ten or
    # more prompts and silently scramble the comparison order.
    generated_texts = [results[key] for key in sorted(results, key=int)]
    assert generated_texts == expected_output
126+
127+
128+
def test_llama_multilora_1x(sql_lora_files):
    """Run the multi-LoRA scenario with a tensor-parallel size of 1."""
    _test_llama_multilora(sql_lora_files, tp_size=1)
130+
131+
132+
#def test_llama_multilora_2x(sql_lora_files):
133+
# _test_llama_multilora(sql_lora_files, 2)
134+
135+
#def test_llama_multilora_4x(sql_lora_files):
136+
# _test_llama_multilora(sql_lora_files, 4)
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Union
4+
5+
import vllm
6+
from vllm.lora.request import LoRARequest
7+
import os
8+
#from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
9+
10+
# The model is loaded through a symlink to avoid the long-path error
# thrown by the HF Hub validation check. Downloading the model directly
# from the Hub would also work, but requires adding an HF token to the
# repo secrets.
src = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
dst = "test_model"
# Drop a link left over from an earlier run before re-creating it.
if os.path.islink(dst):
    os.unlink(dst)
os.symlink(src, dst)
MODEL_PATH = dst
20+
21+
# Reference completions, one per prompt and in prompt order, used when no
# LoRA adapter is requested (lora_id=0).
EXPECTED_NO_LORA_OUTPUT = [
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",  # noqa: E501
    "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",  # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",  # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for",  # noqa: E501
]

# Reference completions, one per prompt and in prompt order, used when a
# LoRA adapter is requested (non-zero lora_id).
EXPECTED_LORA_OUTPUT = [
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",  # noqa: E501
    " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩ok",  # noqa: E501
    " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",  # noqa: E501
    " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'minnesota lynx' ",  # noqa: E501
    " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ",  # noqa: E501
]
37+
38+
39+
def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              tensorizer_config_dict: Union[dict, None] = None) -> list[str]:
    """Generate completions for six fixed SQL prompts and return their texts.

    Args:
        llm: Object whose ``generate`` method is called with the prompts.
        lora_path: Adapter location used when a LoRA request is built.
        lora_id: Adapter id. A falsy value (0) means no LoRA request is
            sent; otherwise the request is named ``str(lora_id)``.
        tensorizer_config_dict: When not ``None``, forwarded to the
            ``LoRARequest`` as ``tensorizer_config_dict``.

    Returns:
        The text of the first output of every result, in result order.
        Each prompt/text pair is also printed.
    """
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]",  # noqa: E501
    ]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=64,
                                          skip_special_tokens=False,
                                          stop=["[/assistant]"])

    # Decide the LoRA request up front so generate() is called in one place.
    if not lora_id:
        lora_request = None
    elif tensorizer_config_dict is None:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        lora_request = LoRARequest(
            str(lora_id),
            lora_id,
            lora_path,
            tensorizer_config_dict=tensorizer_config_dict)

    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)

    # Echo each prompt/completion pair while collecting the texts.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
81+
82+
83+
def generate_and_test(llm,
                      sql_lora_files,
                      tensorizer_config_dict: Union[dict, None] = None):
    """Check sampling without LoRA and with LoRA ids 1 and 2, interleaved.

    Runs ``do_sample`` four times (lora_id 0, 1, 0, 2) and asserts each
    result equals the matching expected-output list. A banner is printed
    before each run and once more at the end.
    """
    stages = (
        ("lora adapter created", 0, EXPECTED_NO_LORA_OUTPUT),
        ("lora 1", 1, EXPECTED_LORA_OUTPUT),
        ("no lora", 0, EXPECTED_NO_LORA_OUTPUT),
        ("lora 2", 2, EXPECTED_LORA_OUTPUT),
    )
    for banner, adapter_id, expected_texts in stages:
        print(banner)
        sampled_texts = do_sample(
            llm,
            sql_lora_files,
            tensorizer_config_dict=tensorizer_config_dict,
            lora_id=adapter_id)
        assert sampled_texts == expected_texts

    print("removing lora")
111+
112+
113+
#@create_new_process_for_each_test()
114+
def test_llama_lora(sql_lora_files):
    """Build a LoRA-enabled LLM and run the generate-and-compare checks."""
    llm_options = dict(
        enable_lora=True,
        # also test odd max_num_seqs
        max_num_seqs=13,
        max_loras=4,
        dtype='bfloat16',
    )
    llm = vllm.LLM(MODEL_PATH, **llm_options)
    generate_and_test(llm, sql_lora_files)
125+
126+
127+
'''@multi_gpu_test(num_gpus=4)
128+
@create_new_process_for_each_test()
129+
def test_llama_lora_tp4(sql_lora_files):
130+
131+
llm = vllm.LLM(
132+
MODEL_PATH,
133+
enable_lora=True,
134+
max_num_seqs=16,
135+
max_loras=4,
136+
tensor_parallel_size=4,
137+
enable_chunked_prefill=True,
138+
)
139+
generate_and_test(llm, sql_lora_files)
140+
141+
142+
@multi_gpu_test(num_gpus=4)
143+
@create_new_process_for_each_test()
144+
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
145+
146+
llm = vllm.LLM(
147+
MODEL_PATH,
148+
enable_lora=True,
149+
max_num_seqs=16,
150+
max_loras=4,
151+
tensor_parallel_size=4,
152+
fully_sharded_loras=True,
153+
enable_chunked_prefill=True,
154+
)
155+
generate_and_test(llm, sql_lora_files)
156+
157+
158+
@multi_gpu_test(num_gpus=2)
159+
@create_new_process_for_each_test()
160+
def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
161+
sql_lora_huggingface_id):
162+
163+
# Run the tensorizing of the LoRA adapter and the model in a subprocess
164+
# to guarantee cleanup
165+
166+
tp_size = 2
167+
model_name = "model-rank-%03d.tensors"
168+
169+
model_ref = MODEL_PATH
170+
lora_path = sql_lora_huggingface_id
171+
suffix = "test"
172+
try:
173+
result = subprocess.run([
174+
sys.executable,
175+
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
176+
MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
177+
str(tp_size), "serialize", "--serialized-directory",
178+
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
179+
'{"limit_cpu_concurrency": 4}'
180+
],
181+
check=True,
182+
capture_output=True,
183+
text=True)
184+
except subprocess.CalledProcessError as e:
185+
print("Tensorizing failed.")
186+
print("STDOUT:\n", e.stdout)
187+
print("STDERR:\n", e.stderr)
188+
raise
189+
190+
print("STDOUT:\n", result.stdout)
191+
192+
model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
193+
tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
194+
195+
loaded_llm = LLM(model=model_ref,
196+
load_format="tensorizer",
197+
enable_lora=True,
198+
enforce_eager=True,
199+
model_loader_extra_config=tensorizer_config,
200+
max_num_seqs=13,
201+
tensor_parallel_size=2,
202+
max_loras=2)
203+
204+
tc_as_dict = tensorizer_config.to_serializable()
205+
206+
print("lora adapter created")
207+
assert do_sample(loaded_llm,
208+
sql_lora_files,
209+
tensorizer_config_dict=tc_as_dict,
210+
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
211+
212+
print("lora 1")
213+
assert do_sample(loaded_llm,
214+
sql_lora_files,
215+
tensorizer_config_dict=tc_as_dict,
216+
lora_id=1) == EXPECTED_LORA_OUTPUT'''

0 commit comments

Comments
 (0)