diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 4036c7ec..1eae1a69 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -3,6 +3,7 @@
                              initialize_model_parallel)
 import pytest
 import tempfile
+from huggingface_hub import snapshot_download
 
 
 @pytest.fixture
@@ -18,3 +19,14 @@ def dist_init():
     initialize_model_parallel(1, 1)
     yield
     cleanup_dist_env_and_memory()
+
+
+@pytest.fixture(scope="session")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)
diff --git a/tests/unit_tests/lora/test_llama_multilora.py b/tests/unit_tests/lora/test_llama_multilora.py
new file mode 100644
index 00000000..4cf93ba4
--- /dev/null
+++ b/tests/unit_tests/lora/test_llama_multilora.py
@@ -0,0 +1,126 @@
+from typing import Optional
+
+from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
+
+
+def create_test_prompts(
+    lora_path: str
+) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
+    """Create a list of test prompts with their sampling parameters.
+
+    2 requests for base model, 4 requests for the LoRA. We define 2
+    different LoRA adapters (using the same model for demo purposes).
+    """
+    return [
+        (
+            "A robot may not injure a human being",
+            SamplingParams(
+                temperature=0.0,
+                #logprobs=1,
+                #prompt_logprobs=1,
+                max_tokens=128),
+            None),
+        (
+            "To be or not to be,",
+            SamplingParams(
+                temperature=0.0,
+                top_k=5,
+                #presence_penalty=0.2,
+                max_tokens=128),
+            None),
+        (
+            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
+            SamplingParams(
+                temperature=0.0,
+                #logprobs=1,
+                #prompt_logprobs=1,
+                max_tokens=128,
+                stop_token_ids=[32003]),
+            LoRARequest("sql-lora", 1, lora_path)),
+        (
+            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
+            SamplingParams(temperature=0,
+                           max_tokens=128,
+                           stop_token_ids=[32003]),
+            LoRARequest("sql-lora", 1, lora_path)),
+        (
+            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
+            SamplingParams(
+                temperature=0.0,
+                #logprobs=1,
+                #prompt_logprobs=1,
+                max_tokens=128,
+                stop_token_ids=[32003]),
+            LoRARequest("sql-lora2", 2, lora_path)),
+        (
+            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
+            SamplingParams(temperature=0,
+                           max_tokens=128,
+                           stop_token_ids=[32003]),
+            LoRARequest("sql-lora", 1, lora_path)),
+    ]
+
+
+def process_requests(engine: LLMEngine,
+                     test_prompts: list[tuple[str, SamplingParams,
+                                               Optional[LoRARequest]]]):
+    """Continuously process a list of prompts and handle the outputs."""
+    request_id = 0
+    result = {}
+
+    while test_prompts or engine.has_unfinished_requests():
+        if test_prompts:
+            prompt, sampling_params, lora_request = test_prompts.pop(0)
+            engine.add_request(str(request_id),
+                               prompt,
+                               sampling_params,
+                               lora_request=lora_request)
+            request_id += 1
+
+        request_outputs: list[RequestOutput] = engine.step()
+
+        for request_output in request_outputs:
+            if request_output.finished:
+                result[
+                    request_output.request_id] = request_output.outputs[0].text
+    return result
+
+
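+# All prompts above use temperature=0 (greedy decoding), so the completions
+# are deterministic and the test compares them verbatim against the
+# reference strings below.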
+expected_output = [
+    " or, through inaction, allow a human being to come to harm.\nA robot must obey the orders given it by human beings except where such orders would conflict with the First Law.\nA robot must protect its own existence as long as such protection does not conflict with the First or Second Law.\nThe Three Laws of Robotics were created by Isaac Asimov in 1942. They are the foundation of robotics and artificial intelligence.\nThe Three Laws of Robotics are the foundation of robotics and artificial intelligence. They were created by Isaac Asimov in 194",  # noqa: E501
+    " that is the question.\nThe question is not whether you will be a leader, but whether you will be a good leader.\nThe question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The question is not whether you will be a leader, but whether you will be a good leader. The",  # noqa: E501
+    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
+    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",  # noqa: E501
+    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
+    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' "  # noqa: E501
+]
+
+
+def _test_llama_multilora(sql_lora_files, tp_size):
+    """Main function that sets up and runs the prompt processing."""
+    engine_args = EngineArgs(model=MODEL_PATH,
+                             enable_lora=True,
+                             max_loras=2,
+                             max_lora_rank=8,
+                             max_num_seqs=256,
+                             dtype='bfloat16',
+                             tensor_parallel_size=tp_size)
+    engine = LLMEngine.from_engine_args(engine_args)
+    test_prompts = create_test_prompts(sql_lora_files)
+    results = process_requests(engine, test_prompts)
+    generated_texts = [results[key] for key in sorted(results)]
+    assert generated_texts == expected_output
+
+
+def test_llama_multilora_1x(sql_lora_files):
+    _test_llama_multilora(sql_lora_files, 1)
+
+
+#def test_llama_multilora_2x(sql_lora_files):
+#    _test_llama_multilora(sql_lora_files, 2)
+
+#def test_llama_multilora_4x(sql_lora_files):
+#    _test_llama_multilora(sql_lora_files, 4)
diff --git a/tests/unit_tests/lora/test_llama_tp.py b/tests/unit_tests/lora/test_llama_tp.py
new file mode 100644
index 00000000..73cee0e0
--- /dev/null
+++ b/tests/unit_tests/lora/test_llama_tp.py
@@ -0,0 +1,206 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Union
+
+import vllm
+from vllm.lora.request import LoRARequest
+#from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
+
+MODEL_PATH = "/mnt/weka/data/pytorch/llama2/Llama-2-7b-hf"
+
+EXPECTED_NO_LORA_OUTPUT = [
+    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant",  # noqa: E501
+    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",  # noqa: E501
+    "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",  # noqa: E501
+    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio",  # noqa: E501
+    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",  # noqa: E501
+    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for",  # noqa: E501
+]
+EXPECTED_LORA_OUTPUT = [
+    " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
+    " SELECT nationality FROM table_name_11 WHERE elector = 'Anchero Pantaleone' ",  # noqa: E501
+    " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩ok",  # noqa: E501
+    " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",  # noqa: E501
+    " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'minnesota lynx' ",  # noqa: E501
+    " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "  # noqa: E501
+]
+
+
+def do_sample(llm: vllm.LLM,
+              lora_path: str,
+              lora_id: int,
+              tensorizer_config_dict: Union[dict, None] = None) -> list[str]:
+    prompts = [
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",  # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",  # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
+    ]
+
+    sampling_params = vllm.SamplingParams(temperature=0,
+                                          max_tokens=64,
+                                          skip_special_tokens=False,
+                                          stop=["[/assistant]"])
+
+    if tensorizer_config_dict is not None:
+        outputs = llm.generate(
+            prompts,
+            sampling_params,
+            lora_request=LoRARequest(
+                str(lora_id),
+                lora_id,
+                lora_path,
+                tensorizer_config_dict=tensorizer_config_dict)
+            if lora_id else None)
+    else:
+        outputs = llm.generate(
+            prompts,
+            sampling_params,
+            lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+            if lora_id else None)
+    # Print the outputs.
+    generated_texts: list[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def generate_and_test(llm,
+                      sql_lora_files,
+                      tensorizer_config_dict: Union[dict, None] = None):
+    print("lora adapter created")
+    assert do_sample(llm,
+                     sql_lora_files,
+                     tensorizer_config_dict=tensorizer_config_dict,
+                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT
+
+    print("lora 1")
+    assert do_sample(llm,
+                     sql_lora_files,
+                     tensorizer_config_dict=tensorizer_config_dict,
+                     lora_id=1) == EXPECTED_LORA_OUTPUT
+
+    print("no lora")
+    assert do_sample(llm,
+                     sql_lora_files,
+                     tensorizer_config_dict=tensorizer_config_dict,
+                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT
+
+    print("lora 2")
+    assert do_sample(llm,
+                     sql_lora_files,
+                     tensorizer_config_dict=tensorizer_config_dict,
+                     lora_id=2) == EXPECTED_LORA_OUTPUT
+
+    print("removing lora")
+
+
+#@create_new_process_for_each_test()
+def test_llama_lora(sql_lora_files):
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        # also test odd max_num_seqs
+        max_num_seqs=13,
+        max_loras=4,
+        dtype='bfloat16',
+    )
+    generate_and_test(llm, sql_lora_files)
+
+
+'''@multi_gpu_test(num_gpus=4)
+@create_new_process_for_each_test()
+def test_llama_lora_tp4(sql_lora_files):
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        tensor_parallel_size=4,
+        enable_chunked_prefill=True,
+    )
+    generate_and_test(llm, sql_lora_files)
+
+
+@multi_gpu_test(num_gpus=4)
+@create_new_process_for_each_test()
+def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        tensor_parallel_size=4,
+        fully_sharded_loras=True,
+        enable_chunked_prefill=True,
+    )
+    generate_and_test(llm, sql_lora_files)
+
+
+@multi_gpu_test(num_gpus=2)
+@create_new_process_for_each_test()
+def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
+                                            sql_lora_huggingface_id):
+
+    # Run the tensorizing of the LoRA adapter and the model in a subprocess
+    # to guarantee cleanup
+
+    tp_size = 2
+    model_name = "model-rank-%03d.tensors"
+
+    model_ref = MODEL_PATH
+    lora_path = sql_lora_huggingface_id
+    suffix = "test"
+    try:
+        result = subprocess.run([
+            sys.executable,
+            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
+            MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
+            str(tp_size), "serialize", "--serialized-directory",
+            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
+            '{"limit_cpu_concurrency": 4}'
+        ],
+                                check=True,
+                                capture_output=True,
+                                text=True)
+    except subprocess.CalledProcessError as e:
+        print("Tensorizing failed.")
+        print("STDOUT:\n", e.stdout)
+        print("STDERR:\n", e.stderr)
+        raise
+
+    print("STDOUT:\n", result.stdout)
+
+    model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
+    tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
+
+    loaded_llm = LLM(model=model_ref,
+                     load_format="tensorizer",
+                     enable_lora=True,
+                     enforce_eager=True,
+                     model_loader_extra_config=tensorizer_config,
+                     max_num_seqs=13,
+                     tensor_parallel_size=2,
+                     max_loras=2)
+
+    tc_as_dict = tensorizer_config.to_serializable()
+
+    print("lora adapter created")
+    assert do_sample(loaded_llm,
+                     sql_lora_files,
+                     tensorizer_config_dict=tc_as_dict,
+                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT
+
+    print("lora 1")
+    assert do_sample(loaded_llm,
+                     sql_lora_files,
+                     tensorizer_config_dict=tc_as_dict,
+                     lora_id=1) == EXPECTED_LORA_OUTPUT'''
diff --git a/vllm_gaudi/lora/punica_wrapper/punica_hpu.py b/vllm_gaudi/lora/punica_wrapper/punica_hpu.py
new file mode 100644
index 00000000..69ea0bc5
--- /dev/null
+++ b/vllm_gaudi/lora/punica_wrapper/punica_hpu.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional, Union, final
+
+import torch
+from vllm_gaudi.extension.ops import (dispatch_bgmv_embedding,
+                                      dispatch_bgmv_linear)
+
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+
+@final
+class PunicaWrapperHPU(PunicaWrapperBase):
+
+    def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                 device: Union[torch.device, str], **kwargs):
+        # Increasing max_num_batched_tokens by 3x to handle increase in
+        # tensor size due to padding.
+        # TODO: Need to check if this override is still required
+        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
+                                   max_batches, device)
+
+    def add_lora_embedding(self,
+                           y: torch.Tensor,
+                           x: torch.Tensor,
+                           lora_b_stacked: torch.Tensor,
+                           add_inputs: bool = True,
+                           **kwargs) -> None:
+        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
+
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: tuple[torch.Tensor, ...],
+                        lora_b_stacked: tuple[torch.Tensor, ...],
+                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
+                        scale: float,
+                        output_slices: tuple[int, ...],
+                        *,
+                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
+                        **kwargs) -> None:
+        x = x.view(-1, x.shape[-1])
+        offset_left = 0
+
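+        # Each entry of output_slices owns its own column block of y and its
+        # own stacked LoRA A/B pair (e.g. the q/k/v parts of a fused
+        # projection), so the bgmv kernel is dispatched once per slice below.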
+        for slice_idx in range(len(output_slices)):
+            dispatch_bgmv_linear(
+                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
+                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
+            offset_left += output_slices[slice_idx]
+
+    def add_lora_logits(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
+                        scale,
+                        *,
+                        buffer: Optional[torch.Tensor] = None,
+                        **kwargs) -> None:
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        x = x.view(-1, x.shape[-1])
+        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
+        y = y.view_as(y_org)
+
+    def add_shrink(
+        self,
+        y: Union[tuple[torch.Tensor, ...], torch.Tensor],
+        x: torch.Tensor,
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        scale: float,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError
+
+    def add_expand(
+        self,
+        y: torch.Tensor,
+        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
+        lora_b_stacked: tuple[torch.Tensor, ...],
+        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
+        output_slices: tuple[int, ...],
+        offset_start: int = 0,
+        add_inputs=True,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError
diff --git a/vllm_gaudi/ops/hpu_lora.py b/vllm_gaudi/ops/hpu_lora.py
index 9d254ff1..58a2265c 100644
--- a/vllm_gaudi/ops/hpu_lora.py
+++ b/vllm_gaudi/ops/hpu_lora.py
@@ -1,31 +1,33 @@
 import torch
 import torch.nn.functional as F
-from vllm.model_executor.custom_op import CustomOp
 from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
+from vllm.lora import layers
+from vllm.platforms import current_platform
+from typing import Optional
 
 
-@CustomOp.register_oot(name='VocabParallelEmbeddingWithLoRA')
 class HPUVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
 
-    def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
-        # x need to reshaped into 2d as batch is there
-        # can be removed on moving to flat tensors
-        shape = x.shape
-        x = x.view(shape[0] * shape[1])
-
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        # flatten to get num_tokens since HPU uses 2d input layout
+        # reshape indices_1, indices_0 to match shape of input
+        num_tokens = x.view(-1).shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[
+            1][:num_tokens].view_as(x)
+        indices_0 = self.punica_wrapper._embeddings_indices[
+            0][:num_tokens].view_as(x)
+
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))
 
         full_output_org = full_output
         if full_output.ndim == 3:
@@ -37,11 +39,20 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
                 full_lora_a_embeddings.shape[1],
                 -1,
             )
-        self.punica_wrapper.add_lora_embedding(full_output,
-                                               full_lora_a_embeddings,
-                                               self.lora_b_stacked,
-                                               add_input=True)
-        # can be removed on moving to flat tensors
-        full_output_org = full_output_org.view(shape[0], shape[1],
-                                               full_output_org.shape[1])
+
+        lora_output: Optional[
+            torch.Tensor] = self.punica_wrapper.add_lora_embedding(
+                full_output,
+                full_lora_a_embeddings,
+                self.lora_b_stacked,
+                add_input=True)
+
+        if not current_platform.can_update_inplace():
+            full_output = lora_output
+
         return full_output.view_as(full_output_org)
+
+
+# refer to https://github.com/vllm-project/vllm/pull/21923 for more details
+# on why this patching is needed.
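+# Rebinding the symbol in vllm.lora.layers means that any code resolving
+# VocabParallelEmbeddingWithLoRA through that module picks up the HPU
+# implementation above.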
+layers.VocabParallelEmbeddingWithLoRA = HPUVocabParallelEmbeddingWithLoRA
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 3e078927..65d67116 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -15,6 +15,7 @@
 import torch
 import torch.distributed
 import torch.nn.functional as F
+import torch.nn as nn
 import vllm_gaudi.extension.environment as environment
 from vllm_gaudi.extension.bucketing.common import HPUBucketingManager
 from vllm_gaudi.extension.defragmentation import OnlineDefragmenter
@@ -66,6 +67,12 @@
                                         sanity_check_mm_encoder_outputs,
                                         scatter_mm_placeholders)
 from vllm.v1.sample.logits_processor import build_logitsprocs
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
+from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm_gaudi.extension.ops import LoraMask as LoraMask
 
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -411,6 +418,10 @@ def forward(self, *args, **kwargs):
         # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
         kwargs = kwargs.copy()
         # selected_token_indices = kwargs.pop('selected_token_indices')
+        if 'lora_mask' in kwargs:
+            lora_mask = kwargs['lora_mask']
+            LoraMask.setLoraMask(lora_mask)
+            kwargs.pop('lora_mask')
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
         input_ids = kwargs['input_ids']
@@ -684,6 +695,115 @@ def __init__(
         self.defragmenter = OnlineDefragmenter()
         self.debug_fwd = init_debug_logger('fwd')
 
+    def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: list[int],
+                         is_prompt: bool):
+        '''
+        This is a helper function to create the mask for lora computations.
+        Lora Mask is needed to ensure we match the correct lora weights for
+        the request.
+        For Prompt phase we have
+        lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
+        lora_logits_mask with shape (batch_size, max_loras * max_rank)
+        For Decode phase we have both
+        lora_mask and lora_logits_mask with shape
+        (batch_size, max_loras * max_rank)
+        '''
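+        # Illustrative example: with max_loras=2, max_lora_rank=8,
+        # batch_size=2 and seq_len=3, the prompt phase builds lora_mask of
+        # shape (6, 16) and lora_logits_mask of shape (2, 16); in the decode
+        # phase both masks have shape (2, 16).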
+        lora_mask: torch.Tensor = None
+        lora_logits_mask: torch.Tensor = None
+        lora_index = 0
+
+        if self.lora_config:
+            if is_prompt:
+                lora_mask = torch.zeros(
+                    input_tokens.shape[0] * input_tokens.shape[1],
+                    (self.lora_config.max_loras) *\
+                    self.lora_config.max_lora_rank,
+                    dtype=self.lora_config.lora_dtype)
+                lora_logits_mask = torch.zeros(
+                    input_tokens.shape[0], (self.lora_config.max_loras) *
+                    self.lora_config.max_lora_rank,
+                    dtype=self.lora_config.lora_dtype)
+
+                ones = torch.ones(input_tokens.shape[1],
+                                  self.lora_config.max_lora_rank,
+                                  dtype=self.lora_config.lora_dtype)
+                logit_ones = torch.ones(1,
+                                        self.lora_config.max_lora_rank,
+                                        dtype=self.lora_config.lora_dtype)
+
+                for i in range(len(lora_ids)):
+                    if lora_ids[i] == 0:
+                        continue
+                    lora_index = self.lora_manager._adapter_manager.\
+                        lora_index_to_id.index(lora_ids[i])
+                    start_row = i * input_tokens.shape[1]
+                    end_row = start_row + input_tokens.shape[1]
+                    start_col = lora_index * self.lora_config.max_lora_rank
+                    end_col = start_col + self.lora_config.max_lora_rank
+                    lora_mask[start_row:end_row, start_col:end_col] = ones
+                    lora_logits_mask[i, start_col:end_col] = logit_ones
+                lora_mask = lora_mask.to('hpu')
+                lora_logits_mask = lora_logits_mask.to('hpu')
+            else:
+                lora_mask = torch.zeros(input_tokens.shape[0],
+                                        (self.lora_config.max_loras) *
+                                        self.lora_config.max_lora_rank,
+                                        dtype=self.lora_config.lora_dtype)
+                ones = torch.ones(1,
+                                  self.lora_config.max_lora_rank,
+                                  dtype=self.lora_config.lora_dtype)
+                for i in range(len(lora_ids)):
+                    if lora_ids[i] == 0:
+                        continue
+                    lora_index = self.lora_manager._adapter_manager.\
+                        lora_index_to_id.index(lora_ids[i])
+                    start_pos = lora_index * self.lora_config.max_lora_rank
+                    end_pos = start_pos + self.lora_config.max_lora_rank
+                    lora_mask[i, start_pos:end_pos] = ones
+                lora_mask = lora_mask.to('hpu')
+                lora_logits_mask = lora_mask
+
+        return lora_mask, lora_logits_mask
+
+    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
+                        scheduler_config: SchedulerConfig,
+                        lora_config: LoRAConfig, device: str) -> nn.Module:
+
+        if not supports_lora(model):
+            raise ValueError(
+                f"{model.__class__.__name__} does not support LoRA yet.")
+
+        if supports_multimodal(model):
+            logger.warning("Regarding multimodal models, vLLM currently "
+                           "only supports adding LoRA to language model.")
+
+        # Use get_text_config() in case of multimodal models
+        text_config = model_config.hf_config.get_text_config()
+
+        # Add LoRA Manager to the Model Runner
+        self.lora_manager = LRUCacheWorkerLoRAManager(
+            scheduler_config.max_num_seqs,
+            scheduler_config.max_num_batched_tokens,
+            model_config.get_vocab_size(),
+            lora_config,
+            device,
+            model.embedding_modules,
+            model.embedding_padding_modules,
+            max_position_embeddings=text_config.max_position_embeddings,
+        )
+        return self.lora_manager.create_lora_manager(model)
+
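+    # set_active_loras is called from _configure_lora before each
+    # prefill/decode batch so that the adapters referenced by the batch are
+    # activated before the LoRA masks are built.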
+    def set_active_loras(self, lora_requests: set[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
+
+    def remove_all_loras(self):
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.remove_all_adapters()
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
@@ -1720,6 +1840,8 @@ def _execute_model_generic(self,
                                attn_metadata,
                                logits_indices,
                                kv_caches,
+                               lora_logits_mask,
+                               lora_mask,
                                warmup_mode=False,
                                inputs_embeds=None,
                                model_mm_kwargs=None):
@@ -1754,13 +1876,15 @@
                 attn_metadata=trimmed_attn_metadata,
                 kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
-                model_mm_kwargs=model_mm_kwargs)
+                model_mm_kwargs=model_mm_kwargs,
+                lora_mask=lora_mask)
         # NOTE(kzawora): returning hidden_states is required in prompt logprobs
         # scenarios, as they will do logit processing on their own
         non_flattened_hidden_states = hidden_states
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = hidden_states[logits_indices]
+        LoraMask.setLoraMask(lora_logits_mask)
         with self.profiler.record_event('internal',
                                         ('compute_logits'
                                          f'{batch_size}_'
                                          f'seq{seq_len}_ctx'
@@ -1922,6 +2046,51 @@ apply_grammar_bitmask(
             logits.copy_(
                 logits_cpu.to(self.device, non_blocking=True).to(logits.dtype))
 
+    def _configure_lora(self, input, requests, req_ids, is_prompt):
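+        # Prompt phase: every request contributes seq_len entries to
+        # lora_index_mapping (one per token) and a single entry to
+        # lora_prompt_mapping. Decode phase: one entry per (possibly padded)
+        # batch slot in both mappings.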
+        lora_mask = None
+        lora_logits_mask = None
+        if self.lora_config:
+            if is_prompt:
+                lora_requests = [] if req_ids else requests
+                lora_ids = []
+                lora_index_mapping = []
+                lora_prompt_mapping = []
+                for i, r_id in enumerate(req_ids):
+                    lora_requests.append(requests[r_id].lora_request)
+                for lora_req in lora_requests:
+                    lora_id = lora_req.lora_int_id if lora_req else 0
+                    lora_index_mapping += [lora_id] * (input.shape[1])
+                    #TODO: This may need to change when logprobs
+                    # sampling is enabled
+                    lora_prompt_mapping += [lora_id]
+                    lora_ids.append(lora_id)
+            else:
+                lora_requests = []
+                # lora_ids, lora_index_mapping, lora_prompt_mapping
+                # filled with 0 (indicating no lora) to account for
+                # any padding
+                lora_ids = [0] * input.shape[0]
+                lora_index_mapping = [0] * input.shape[0]
+                lora_prompt_mapping = [0] * input.shape[0]
+                for i, r_id in enumerate(req_ids):
+                    lora_requests.append(requests[r_id].lora_request)
+
+                for i, lora_req in enumerate(lora_requests):
+                    lora_id = lora_req.lora_int_id if lora_req else 0
+                    lora_index_mapping[i] = lora_id
+                    lora_prompt_mapping[i] = lora_id
+                    lora_ids[i] = lora_id
+
+            # is_prefill should always be "False" for HPU
+            lora_mapping = LoRAMapping(lora_index_mapping,
+                                       lora_prompt_mapping,
+                                       is_prefill=False)
+            self.set_active_loras(lora_requests, lora_mapping)
+            lora_mask, lora_logits_mask = self.create_lora_mask(
+                input, lora_ids, is_prompt)
+
+        return lora_mask, lora_logits_mask
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2060,6 +2229,9 @@
                     device=self.device,
                 )
 
+            lora_mask, lora_logits_mask = self._configure_lora(
+                token_ids, self.requests, req_id, True)
+
             self.event_start = self.profiler.get_timestamp_us()
             self.profiler.start("internal", "prefill")
             # Align behavior of incomplete prompt with gpu_model_runner
@@ -2075,6 +2247,8 @@
                 self._execute_model_generic(
                     token_ids, position_ids, attn_metadata, logits_indices,
                     self.kv_caches,
+                    lora_logits_mask,
+                    lora_mask,
                     inputs_embeds=inputs_embeds,
                     model_mm_kwargs=model_mm_kwargs,
                     warmup_mode=warmup_mode)
@@ -2115,9 +2289,12 @@
         ######################### DECODES #########################
         # Decodes run as one single batch with [padded_decode_bs, 1]
         if num_decodes > 0:
+            assert decode_data is not None
+            lora_mask, lora_logits_mask = self._configure_lora(
+                decode_data.token_ids, self.requests, pd_info.decode_req_ids,
+                False)
             self.event_start = self.profiler.get_timestamp_us()
             self.profiler.start("internal", "decode")
-            assert decode_data is not None
             htorch.core.mark_step()
             _, logits_device = self._execute_model_generic(
                 decode_data.token_ids,
@@ -2125,6 +2302,8 @@
                 decode_data.attn_metadata,
                 decode_data.logits_indices,
                 self.kv_caches,
+                lora_logits_mask,
+                lora_mask,
                 warmup_mode=warmup_mode)
             htorch.core.mark_step()
@@ -2280,6 +2459,12 @@ def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with HabanaMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            if self.lora_config:
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)
         self.model_memory_usage = m.consumed_device_memory
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))