
Commit 9d92fb2

Add LoRA unit-test
Signed-off-by: Vivek <[email protected]>
1 parent d540834 commit 9d92fb2

File tree

4 files changed: +223, -5 lines changed

tests/unit_tests/conftest.py

Lines changed: 12 additions & 0 deletions
@@ -3,6 +3,7 @@
                               initialize_model_parallel)
 import pytest
 import tempfile
+from huggingface_hub import snapshot_download
 
 
 @pytest.fixture
@@ -18,3 +19,14 @@ def dist_init():
     initialize_model_parallel(1, 1)
     yield
     cleanup_dist_env_and_memory()
+
+
+@pytest.fixture(scope="session")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)
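For context, these session-scoped fixtures resolve the adapter repo once per test session and hand the local download path to any test that requests it. A minimal sketch of a consuming test (the test name and the adapter_config.json check are illustrative assumptions, not part of this commit):

import os

def test_sql_lora_files_downloaded(sql_lora_files):
    # snapshot_download returns the local directory of the fetched repo;
    # a PEFT-style LoRA adapter is expected to ship an adapter_config.json.
    assert os.path.isfile(os.path.join(sql_lora_files, "adapter_config.json"))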
Lines changed: 208 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union

import vllm
from vllm.lora.request import LoRARequest

# from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test

MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
# MODEL_PATH = "meta-llama/Llama-2-7b-hf"

EXPECTED_NO_LORA_OUTPUT = [
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",  # noqa: E501
    "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",  # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",  # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for",  # noqa: E501
]
EXPECTED_LORA_OUTPUT = [
    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",  # noqa: E501
    " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩ok",  # noqa: E501
    " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",  # noqa: E501
    " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'minnesota lynx' ",  # noqa: E501
    " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "  # noqa: E501
]


def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              tensorizer_config_dict: Union[dict, None] = None) -> list[str]:
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
    ]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=64,
                                          skip_special_tokens=False,
                                          stop=["[/assistant]"])

    if tensorizer_config_dict is not None:
        outputs = llm.generate(
            prompts,
            sampling_params,
            lora_request=LoRARequest(
                str(lora_id),
                lora_id,
                lora_path,
                tensorizer_config_dict=tensorizer_config_dict)
            if lora_id else None)
    else:
        outputs = llm.generate(
            prompts,
            sampling_params,
            lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
            if lora_id else None)
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def generate_and_test(llm,
                      sql_lora_files,
                      tensorizer_config_dict: Union[dict, None] = None):
    print("lora adapter created")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=1) == EXPECTED_LORA_OUTPUT

    print("no lora")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 2")
    assert do_sample(llm,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=2) == EXPECTED_LORA_OUTPUT

    print("removing lora")


# @create_new_process_for_each_test()
def test_llama_lora(sql_lora_files):

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        # also test odd max_num_seqs
        max_num_seqs=13,
        max_loras=4,
        dtype='bfloat16',
    )
    generate_and_test(llm, sql_lora_files)


'''@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4(sql_lora_files):

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):

    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        fully_sharded_loras=True,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=2)
@create_new_process_for_each_test()
def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
                                            sql_lora_huggingface_id):

    # Run the tensorizing of the LoRA adapter and the model in a subprocess
    # to guarantee cleanup

    tp_size = 2
    model_name = "model-rank-%03d.tensors"

    model_ref = MODEL_PATH
    lora_path = sql_lora_huggingface_id
    suffix = "test"
    try:
        result = subprocess.run([
            sys.executable,
            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
            MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
            str(tp_size), "serialize", "--serialized-directory",
            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
            '{"limit_cpu_concurrency": 4}'
        ],
                                check=True,
                                capture_output=True,
                                text=True)
    except subprocess.CalledProcessError as e:
        print("Tensorizing failed.")
        print("STDOUT:\n", e.stdout)
        print("STDERR:\n", e.stderr)
        raise

    print("STDOUT:\n", result.stdout)

    model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
    tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))

    loaded_llm = LLM(model=model_ref,
                     load_format="tensorizer",
                     enable_lora=True,
                     enforce_eager=True,
                     model_loader_extra_config=tensorizer_config,
                     max_num_seqs=13,
                     tensor_parallel_size=2,
                     max_loras=2)

    tc_as_dict = tensorizer_config.to_serializable()

    print("lora adapter created")
    assert do_sample(loaded_llm,
                     sql_lora_files,
                     tensorizer_config_dict=tc_as_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(loaded_llm,
                     sql_lora_files,
                     tensorizer_config_dict=tc_as_dict,
                     lora_id=1) == EXPECTED_LORA_OUTPUT'''
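For readers unfamiliar with the API the test exercises: LoRARequest's positional arguments are a human-readable adapter name, a positive integer id, and the adapter path, and do_sample passes None when lora_id is 0 so the base model answers. A minimal standalone sketch of the same pattern (the prompt text and adapter path below are placeholders, not from this commit):

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=True)
outputs = llm.generate(
    ["[user] Write a SQL query to answer the question ... [/user] [assistant]"],
    vllm.SamplingParams(temperature=0, max_tokens=64),
    # LoRARequest(lora_name, lora_int_id, lora_path)
    lora_request=LoRARequest("sql_lora", 1, "/path/to/adapter"),
)
print(outputs[0].outputs[0].text)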

vllm_gaudi/platform.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if (vllm_config.model_config is not None
                 and vllm_config.model_config.dtype
-                in (torch.float16,)):
+                in (torch.float16, torch.float32)):
             logger.warning(
                 "The HPU backend currently does not support %s. "
                 "Using bfloat16 instead.", vllm_config.model_config.dtype)

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 2 additions & 4 deletions
@@ -453,12 +453,9 @@ def generate_proposals(self, *args, **kwargs):
 
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
-    '''
     return htorch.hpu.wrap_in_hpu_graph(
         HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
     ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs)
-    '''
-    return HpuModelAdapter(*args, **kwargs)
 
 
 def subtuple(obj: object,
@@ -2234,7 +2231,8 @@ def warmup_scenario(self,
         htorch.core.mark_step()
         _ = self._execute_model_generic(input_ids_device, position_ids_device,
                                         attn_metadata, logits_indices_device,
-                                        kv_caches, lora_logits_mask, lora_mask, True)
+                                        kv_caches, lora_logits_mask, lora_mask,
+                                        True)
         # TODO: do sampling on logits, warmup sampler and prefill joiner
         htorch.core.mark_step()
         self.profiler.end()
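For skimmers: the first hunk restores HPU graph wrapping, which this function had previously short-circuited behind a triple-quoted block, while the second hunk only re-wraps an argument list and changes no behavior. A minimal sketch of the conditional-wrapping pattern the restored code follows (assuming the module's usual "import habana_frameworks.torch as htorch" binding; the adapter argument stands in for HpuModelAdapter):

import habana_frameworks.torch as htorch

def maybe_wrap(adapter):
    # In lazy mode, capture the adapter as a replayable HPU graph;
    # otherwise return it unwrapped for eager execution.
    if htorch.utils.internal.is_lazy():
        return htorch.hpu.wrap_in_hpu_graph(adapter,
                                            disable_tensor_cache=True)
    return adapter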
