Skip to content

Commit 3b22757

Browse files
committed
Fix for model load error in CI
Signed-off-by: Vivek <[email protected]>
1 parent 81e6e72 commit 3b22757

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

tests/unit_tests/lora/test_llama_multilora.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
44
from vllm.lora.request import LoRARequest
5+
import os
6+
7+
# Need to create a symlink to avoid the long-path error thrown by the
# HF Hub validation check. Downloading the model directly from the Hub
# would also work, but would require adding an HF token to repo secrets.
src = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
dst = "test_model"
# Clear any stale entry at `dst` before linking. `lexists` (unlike the
# previous `islink` check) is also true for regular files and broken
# symlinks, so a dirty checkout no longer makes os.symlink raise
# FileExistsError.
if os.path.lexists(dst):
    os.remove(dst)
os.symlink(src, dst)
MODEL_PATH = dst
517

618

719
def create_test_prompts(
@@ -99,14 +111,13 @@ def process_requests(engine: LLMEngine,
99111

100112
def _test_llama_multilora(sql_lora_files, tp_size):
101113
"""Main function that sets up and runs the prompt processing."""
102-
engine_args = EngineArgs(
103-
model="/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf",
104-
enable_lora=True,
105-
max_loras=2,
106-
max_lora_rank=8,
107-
max_num_seqs=256,
108-
dtype='bfloat16',
109-
tensor_parallel_size=tp_size)
114+
engine_args = EngineArgs(model=MODEL_PATH,
115+
enable_lora=True,
116+
max_loras=2,
117+
max_lora_rank=8,
118+
max_num_seqs=256,
119+
dtype='bfloat16',
120+
tensor_parallel_size=tp_size)
110121
engine = LLMEngine.from_engine_args(engine_args)
111122
test_prompts = create_test_prompts(sql_lora_files)
112123
results = process_requests(engine, test_prompts)

tests/unit_tests/lora/test_llama_tp.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@
44

55
import vllm
66
from vllm.lora.request import LoRARequest
7-
7+
import os
88
#from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
99

10-
MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
11-
#MODEL_PATH = "meta-llama/Llama-2-7b-hf"
10+
# Need to create a symlink to avoid the long-path error thrown by the
# HF Hub validation check. Downloading the model directly from the Hub
# would also work, but would require adding an HF token to repo secrets.
src = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
dst = "test_model"
# Clear any stale entry at `dst` before linking. `lexists` (unlike the
# previous `islink` check) is also true for regular files and broken
# symlinks, so a dirty checkout no longer makes os.symlink raise
# FileExistsError.
if os.path.lexists(dst):
    os.remove(dst)
os.symlink(src, dst)
MODEL_PATH = dst
1220

1321
EXPECTED_NO_LORA_OUTPUT = [
1422
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant", # noqa: E501

0 commit comments

Comments
 (0)