|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams |
| 4 | +from vllm.lora.request import LoRARequest |
| 5 | + |
| 6 | + |
| 7 | +def create_test_prompts( |
| 8 | + lora_path: str |
| 9 | +) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: |
| 10 | + """Create a list of test prompts with their sampling parameters. |
| 11 | +
|
| 12 | + 2 requests for base model, 4 requests for the LoRA. We define 2 |
| 13 | + different LoRA adapters (using the same model for demo purposes). |
| 14 | + """ |
| 15 | + return [ |
| 16 | + ( |
| 17 | + "A robot may not injure a human being", |
| 18 | + SamplingParams( |
| 19 | + temperature=0.0, |
| 20 | + #logprobs=1, |
| 21 | + #prompt_logprobs=1, |
| 22 | + max_tokens=128), |
| 23 | + None), |
| 24 | + ( |
| 25 | + "To be or not to be,", |
| 26 | + SamplingParams( |
| 27 | + temperature=0.0, |
| 28 | + top_k=5, |
| 29 | + #presence_penalty=0.2, |
| 30 | + max_tokens=128), |
| 31 | + None), |
| 32 | + ( |
| 33 | + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 |
| 34 | + SamplingParams( |
| 35 | + temperature=0.0, |
| 36 | + #logprobs=1, |
| 37 | + #prompt_logprobs=1, |
| 38 | + max_tokens=128, |
| 39 | + stop_token_ids=[32003]), |
| 40 | + LoRARequest("sql-lora", 1, lora_path)), |
| 41 | + ( |
| 42 | + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 |
| 43 | + SamplingParams(temperature=0, |
| 44 | + max_tokens=128, |
| 45 | + stop_token_ids=[32003]), |
| 46 | + LoRARequest("sql-lora", 1, lora_path)), |
| 47 | + ( |
| 48 | + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 |
| 49 | + SamplingParams( |
| 50 | + temperature=0.0, |
| 51 | + #logprobs=1, |
| 52 | + #prompt_logprobs=1, |
| 53 | + max_tokens=128, |
| 54 | + stop_token_ids=[32003]), |
| 55 | + LoRARequest("sql-lora2", 2, lora_path)), |
| 56 | + ( |
| 57 | + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 |
| 58 | + SamplingParams(temperature=0, |
| 59 | + max_tokens=128, |
| 60 | + stop_token_ids=[32003]), |
| 61 | + LoRARequest("sql-lora", 1, lora_path)), |
| 62 | + ] |
| 63 | + |
| 64 | + |
| 65 | +def process_requests(engine: LLMEngine, |
| 66 | + test_prompts: list[tuple[str, SamplingParams, |
| 67 | + Optional[LoRARequest]]]): |
| 68 | + """Continuously process a list of prompts and handle the outputs.""" |
| 69 | + request_id = 0 |
| 70 | + result = {} |
| 71 | + |
| 72 | + while test_prompts or engine.has_unfinished_requests(): |
| 73 | + if test_prompts: |
| 74 | + prompt, sampling_params, lora_request = test_prompts.pop(0) |
| 75 | + engine.add_request(str(request_id), |
| 76 | + prompt, |
| 77 | + sampling_params, |
| 78 | + lora_request=lora_request) |
| 79 | + request_id += 1 |
| 80 | + |
| 81 | + request_outputs: list[RequestOutput] = engine.step() |
| 82 | + |
| 83 | + for request_output in request_outputs: |
| 84 | + if request_output.finished: |
| 85 | + result[ |
| 86 | + request_output.request_id] = request_output.outputs[0].text |
| 87 | + return result |
| 88 | + |
| 89 | + |
| 90 | +expected_output = [ |
| 91 | + " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194", # noqa: E501 |
| 92 | + " that is the question.\nThe question is not whether you will be a leader, but whether you will be a good leader.\nThe question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The", # noqa: E501 |
| 93 | + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 |
| 94 | + " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ", # noqa: E501 |
| 95 | + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 |
| 96 | + " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' " # noqa: E501 |
| 97 | +] |
| 98 | + |
| 99 | + |
| 100 | +def _test_llama_multilora(sql_lora_files, tp_size): |
| 101 | + """Main function that sets up and runs the prompt processing.""" |
| 102 | + engine_args = EngineArgs( |
| 103 | + model="/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf", |
| 104 | + enable_lora=True, |
| 105 | + max_loras=2, |
| 106 | + max_lora_rank=8, |
| 107 | + max_num_seqs=256, |
| 108 | + dtype='bfloat16', |
| 109 | + tensor_parallel_size=tp_size) |
| 110 | + engine = LLMEngine.from_engine_args(engine_args) |
| 111 | + test_prompts = create_test_prompts(sql_lora_files) |
| 112 | + results = process_requests(engine, test_prompts) |
| 113 | + generated_texts = [results[key] for key in sorted(results)] |
| 114 | + assert generated_texts == expected_output |
| 115 | + |
| 116 | + |
| 117 | +def test_llama_multilora_1x(sql_lora_files): |
| 118 | + _test_llama_multilora(sql_lora_files, 1) |
| 119 | + |
| 120 | + |
| 121 | +#def test_llama_multilora_2x(sql_lora_files): |
| 122 | +# _test_llama_multilora(sql_lora_files, 2) |
| 123 | + |
| 124 | +#def test_llama_multilora_4x(sql_lora_files): |
| 125 | +# _test_llama_multilora(sql_lora_files, 4) |
0 commit comments