77 changes: 77 additions & 0 deletions docs/features/lora.md
@@ -275,3 +275,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
    ]
}
```

## Default LoRA Models For Multimodal Models

Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. Managing this with the approaches above can be tedious, since it requires the user to send a `LoRARequest` (offline) or to route requests between the base model and the LoRA model (server) depending on the content of the request's multimodal data.

To this end, we allow registration of default multimodal LoRAs to handle this automatically: users can map each modality to a LoRA adapter, and it is applied automatically whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if several modalities are provided, each of which is registered to its own LoRA adapter, none of them will be applied.

Example usage for offline inference:

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-3.3-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)


model = LLM(
    model=model_id,
    enable_lora=True,
    max_lora_rank=64,
    max_model_len=2048,
    limit_mm_per_prompt={"audio": 1},
    # Will always pass a `LoRARequest` with the `model_id`
    # whenever audio is contained in the request data.
    default_mm_loras={"audio": model_id},
    enforce_eager=True,
)

question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    },
}


outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
```
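To inspect the generated text from the example above, a short follow-up using the standard `RequestOutput` fields:

```python
for output in outputs:
    # Each RequestOutput carries the generated candidates; print the first.
    print(output.outputs[0].text)
```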

You can also pass a JSON dictionary via `--default-mm-loras` that maps modalities to LoRA model IDs. For example, when starting the server:

```bash
vllm serve ibm-granite/granite-speech-3.3-2b \
    --max-model-len 2048 \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64
```

Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
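
For completeness, a minimal client-side sketch against the server command above (assumes the server is running locally on the default port; the audio URL is a hypothetical placeholder):

```python
from openai import OpenAI

# Assumes the `vllm serve` command above is running on the default port;
# the audio URL below is a hypothetical placeholder.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
    max_completion_tokens=64,
    temperature=0.0,
)
# The speech LoRA is applied automatically because the request contains audio.
print(chat_completion.choices[0].message.content)
```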
107 changes: 107 additions & 0 deletions tests/entrypoints/openai/test_default_mm_loras.py
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer

# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.

# Contains a modality specific lora alongside the base model
MULTIMODAL_MODEL_NAME = snapshot_download(
    "microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")

ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module):  # noqa: F811

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--max-model-len",
        "12800",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"speech={AUDIO_LORA_PATH}",
        "--max-lora-rank",
        "320",
        "--max-num-seqs",
        "2",
        "--trust-remote-code",
        "--gpu-memory-utilization",
        "0.8",
        "--default-mm-loras",
        f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
    ]

    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
    async with multimodal_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # base model with default lora should give the same response as lora model
    "model_name",
    [MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
    model_name: str,
    multi_modal_client: openai.AsyncOpenAI,
    audio_assets: AudioTestAssets,
):
    messages = [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": "Can you transcribe this audio?",
        }, {
            "type": "audio_url",
            "audio_url": {
                "url": audio_assets[0].url
            },
        }]
    }]

    chat_completion = await multi_modal_client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=128,
        temperature=0.0)

    assert len(chat_completion.choices) > 0

    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
    assert message.content == ACTIVE_MM_LORA_RESPONSE
3 changes: 2 additions & 1 deletion tests/entrypoints/openai/test_serving_models.py
@@ -57,7 +57,8 @@ async def test_load_lora_adapter_success():
     response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
     assert len(serving_models.lora_requests) == 1
-    assert serving_models.lora_requests[0].lora_name == "adapter"
+    assert "adapter" in serving_models.lora_requests
+    assert serving_models.lora_requests["adapter"].lora_name == "adapter"


@pytest.mark.asyncio
118 changes: 118 additions & 0 deletions tests/lora/test_default_mm_loras.py
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for applying default registered multimodal loras.
"""

import os

from huggingface_hub import snapshot_download

from vllm.lora.request import LoRARequest

from ..conftest import AudioTestAssets, VllmRunner

MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")

AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>"  # noqa: E501

# Responses are greedy decoded; we just check the end of
# the generated text. If the lora is inactive, this model
# generates commentary on the transcription.
RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501

VLLM_RUNNER_BASE_KWARGS = {
    "model_name": MODEL_PATH,
    "dtype": "half",
    "enable_lora": True,
    "max_num_seqs": 2,
    "max_lora_rank": 320,
    "max_model_len": 12800,
    "gpu_memory_utilization": 0.8,
    "limit_mm_per_prompt": {
        "audio": 1
    },
    "enforce_eager": True,
}


def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
             **kwargs):
    inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]

    # Apply any additional kwargs as overrides to the base kwargs
    vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}

    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_with_default_lora = [
            vllm_model.generate_greedy(
                prompts,
                max_tokens=128,
                audios=audios,
                lora_request=lora_request,
            ) for prompts, audios in inputs
        ]

        assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
            expected_suffix)


def test_active_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that we can use the default audio lora."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_inactive_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that modalities are filtered properly."""
    # Default image lora won't be active since we only pass audio
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"image": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
    )


def test_default_mm_lora_succeeds_with_redundant_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that redundantly providing the lora works."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_default_mm_lora_fails_with_overridden_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that if the lora_request conflicts with default_mm_loras,
    we use the lora_request."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
        default_mm_loras={"audio": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )
11 changes: 11 additions & 0 deletions vllm/config.py
@@ -33,6 +33,7 @@
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
    ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2964,6 +2965,16 @@ class LoRAConfig:
    trained with those scaling factors to be used at the same time. If not
    specified, only adapters trained with the base model scaling factor are
    allowed."""
    default_mm_loras: Optional[dict[str, str]] = None
    """Dictionary mapping specific modalities to LoRA model paths; this field
    is only applicable to multimodal models and should be leveraged when a
    model always expects a LoRA to be active when a given modality is present.
    Note that currently, if a request provides multiple additional
    modalities, each of which has its own LoRA, we do NOT apply
    default_mm_loras because we currently only support one lora adapter
    per prompt. When run in offline mode, the lora IDs for n modalities
    will be automatically assigned to 1-n with the names of the modalities
    in alphabetical order."""
    bias_enabled: bool = False
    """Enable bias for LoRA adapters."""

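As a side note on the docstring above: a minimal sketch of the documented 1-n ID assignment in offline mode (the adapter paths here are hypothetical, and only the sorting/enumeration is illustrated, not vLLM's actual internals):

```python
from vllm.lora.request import LoRARequest

# Hypothetical adapter paths; sketches the documented behavior of
# assigning lora IDs 1..n over modality names in alphabetical order.
default_mm_loras = {
    "image": "/adapters/vision-lora",
    "audio": "/adapters/speech-lora",
}
default_lora_requests = {
    modality: LoRARequest(modality, lora_id, path)
    for lora_id, (modality, path) in enumerate(
        sorted(default_mm_loras.items()), start=1)
}
# -> "audio" maps to lora ID 1, "image" to lora ID 2
```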
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -382,6 +382,8 @@ class EngineArgs:
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -791,6 +793,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument("--fully-sharded-loras",
                                **lora_kwargs["fully_sharded_loras"])
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
@@ -1245,10 +1249,16 @@ def create_engine_config(
            disable_hybrid_kv_cache_manager,
        )

        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            default_mm_loras=self.default_mm_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
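The validation added in `create_engine_config` fails fast when modality defaults are combined with a text-only model; a hypothetical repro sketch (the model name and adapter path are placeholders):

```python
from vllm import LLM

# Hypothetical: a text-only model plus default_mm_loras raises at
# engine construction time with the error added above.
llm = LLM(
    model="facebook/opt-125m",  # not a multimodal model
    enable_lora=True,
    default_mm_loras={"audio": "/adapters/speech-lora"},
)
# ValueError: Default modality-specific LoRA(s) were provided for a
# non multimodal model
```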