12 changes: 12 additions & 0 deletions tests/unit_tests/conftest.py
@@ -3,6 +3,7 @@
initialize_model_parallel)
import pytest
import tempfile
from huggingface_hub import snapshot_download


@pytest.fixture
@@ -18,3 +19,14 @@ def dist_init():
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    # HuggingFace repo ID used to test runtime LoRA downloading.
return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)
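A minimal usage sketch for these session-scoped fixtures (hypothetical test, not part of this diff; assumes the snapshot ships a standard PEFT adapter_config.json):

import os

def test_sql_lora_snapshot(sql_lora_files):
    # snapshot_download returns the local cache directory; the session scope
    # above means the repo is fetched at most once per pytest run.
    assert os.path.isdir(sql_lora_files)
    assert os.path.exists(os.path.join(sql_lora_files, "adapter_config.json"))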
126 changes: 126 additions & 0 deletions tests/unit_tests/lora/test_llama_multilora.py
@@ -0,0 +1,126 @@
from typing import Optional

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"


def create_test_prompts(
lora_path: str
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.

2 requests for base model, 4 requests for the LoRA. We define 2
different LoRA adapters (using the same model for demo purposes).
"""
return [
(
"A robot may not injure a human being",
SamplingParams(
temperature=0.0,
#logprobs=1,
#prompt_logprobs=1,
max_tokens=128),
None),
(
"To be or not to be,",
SamplingParams(
temperature=0.0,
top_k=5,
#presence_penalty=0.2,
max_tokens=128),
None),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(
temperature=0.0,
#logprobs=1,
#prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
            SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(
temperature=0.0,
#logprobs=1,
#prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
            SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
]


def process_requests(engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams,
Optional[LoRARequest]]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
result = {}

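    # Interleave submission and execution: enqueue one prompt per iteration,
    # then engine.step() performs a single scheduling/execution pass and
    # returns RequestOutputs; text is collected once a request finishes.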
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
request_id += 1

request_outputs: list[RequestOutput] = engine.step()

for request_output in request_outputs:
if request_output.finished:
result[
request_output.request_id] = request_output.outputs[0].text
return result


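# Reference completions for greedy decoding (temperature=0.0); results are
# gathered in request-id order, so this list mirrors create_test_prompts().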
expected_output = [
" or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194", # noqa: E501
" that is the question.\nThe question is not whether you will be a leader, but whether you will be a good leader.\nThe question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The", # noqa: E501
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' " # noqa: E501
]


def _test_llama_multilora(sql_lora_files, tp_size):
"""Main function that sets up and runs the prompt processing."""
engine_args = EngineArgs(model=MODEL_PATH,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=256,
dtype='bfloat16',
tensor_parallel_size=tp_size)
engine = LLMEngine.from_engine_args(engine_args)
test_prompts = create_test_prompts(sql_lora_files)
results = process_requests(engine, test_prompts)
generated_texts = [results[key] for key in sorted(results)]
assert generated_texts == expected_output


def test_llama_multilora_1x(sql_lora_files):
_test_llama_multilora(sql_lora_files, 1)


#def test_llama_multilora_2x(sql_lora_files):
# _test_llama_multilora(sql_lora_files, 2)

#def test_llama_multilora_4x(sql_lora_files):
# _test_llama_multilora(sql_lora_files, 4)
206 changes: 206 additions & 0 deletions tests/unit_tests/lora/test_llama_tp.py
@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union

import vllm
from vllm.lora.request import LoRARequest
#from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test

MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"

EXPECTED_NO_LORA_OUTPUT = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for", # noqa: E501
]
EXPECTED_LORA_OUTPUT = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩ok", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'minnesota lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]


def do_sample(llm: vllm.LLM,
lora_path: str,
lora_id: int,
tensorizer_config_dict: Union[dict, None] = None) -> list[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]

sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=64,
skip_special_tokens=False,
stop=["[/assistant]"])

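    # A tensorizer config, when given, is attached to the LoRARequest so the
    # adapter weights are also loaded through tensorizer; lora_id 0 disables
    # LoRA entirely (lora_request=None).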
if tensorizer_config_dict is not None:
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(
str(lora_id),
lora_id,
lora_path,
tensorizer_config_dict=tensorizer_config_dict)
if lora_id else None)
else:
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


def generate_and_test(llm,
sql_lora_files,
tensorizer_config_dict: Union[dict, None] = None):
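    # Alternate between no-LoRA (lora_id=0) and two adapter ids to verify that
    # enabling, switching, and bypassing adapters each reproduce the expected
    # outputs.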
print("lora adapter created")
assert do_sample(llm,
sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 1")
assert do_sample(llm,
sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=1) == EXPECTED_LORA_OUTPUT

print("no lora")
assert do_sample(llm,
sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 2")
assert do_sample(llm,
sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=2) == EXPECTED_LORA_OUTPUT

print("removing lora")


#@create_new_process_for_each_test()
def test_llama_lora(sql_lora_files):

llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=13,
max_loras=4,
dtype='bfloat16',
)
generate_and_test(llm, sql_lora_files)


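# NOTE: the disabled tests below would additionally need (an assumption based
# on upstream vLLM; import paths may differ in this tree):
#   import subprocess, sys
#   from vllm import LLM
#   from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
#   from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test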
'''@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4(sql_lora_files):

llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):

llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=2)
@create_new_process_for_each_test()
def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
sql_lora_huggingface_id):

# Run the tensorizing of the LoRA adapter and the model in a subprocess
# to guarantee cleanup

tp_size = 2
model_name = "model-rank-%03d.tensors"

model_ref = MODEL_PATH
lora_path = sql_lora_huggingface_id
suffix = "test"
try:
result = subprocess.run([
sys.executable,
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
str(tp_size), "serialize", "--serialized-directory",
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
'{"limit_cpu_concurrency": 4}'
],
check=True,
capture_output=True,
text=True)
except subprocess.CalledProcessError as e:
print("Tensorizing failed.")
print("STDOUT:\n", e.stdout)
print("STDERR:\n", e.stderr)
raise

print("STDOUT:\n", result.stdout)

model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))

loaded_llm = LLM(model=model_ref,
load_format="tensorizer",
enable_lora=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config,
max_num_seqs=13,
tensor_parallel_size=2,
max_loras=2)

tc_as_dict = tensorizer_config.to_serializable()

print("lora adapter created")
assert do_sample(loaded_llm,
sql_lora_files,
tensorizer_config_dict=tc_as_dict,
lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 1")
assert do_sample(loaded_llm,
sql_lora_files,
tensorizer_config_dict=tc_as_dict,
lora_id=1) == EXPECTED_LORA_OUTPUT'''