Commit 81bb860 (parent: 5a9007b)

Extend a LoRARequest to support model validation

Extend LoRARequest with a validate() method to enable validation of a LoRA adapter when it is loaded. Add a test case.

Signed-off-by: Stefan Berger <[email protected]>
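
For context, the registry API exercised by the new test suggests how a real out-of-tree validator could hook in. The following is a minimal sketch, assuming only the ModelValidationPlugin interface and the ModelValidationPluginRegistry.register_plugin() call that appear in the diffs below; the SHA-256 allowlist and the plugin name "checksum" are purely illustrative:

import hashlib
import os
from typing import Optional

from vllm.validation.plugins import (
    ModelType,
    ModelValidationPlugin,
    ModelValidationPluginRegistry,
)

# Illustrative allowlist of known-good file digests (hypothetical).
TRUSTED_DIGESTS: set[str] = set()


class ChecksumValidator(ModelValidationPlugin):
    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        # Only gate LoRA adapters; other model types pass through.
        return model_type == ModelType.MODEL_TYPE_LORA

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        # Hash every file in the adapter directory and reject unknown content.
        for root, _, files in os.walk(model_path):
            for name in files:
                with open(os.path.join(root, name), "rb") as f:
                    digest = hashlib.sha256(f.read()).hexdigest()
                if digest not in TRUSTED_DIGESTS:
                    raise ValueError(f"untrusted adapter file: {name}")


ModelValidationPluginRegistry.register_plugin("checksum", ChecksumValidator())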

3 files changed: +48 -0 lines changed


tests/lora/test_lora_manager.py (39 additions, 0 deletions)

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
+from typing import Optional
 
 import pytest
 import torch
@@ -26,6 +27,11 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
 from vllm.platforms import current_platform
+from vllm.validation.plugins import (
+    ModelType,
+    ModelValidationPlugin,
+    ModelValidationPluginRegistry,
+)
 
 from .utils import create_peft_lora
 
@@ -714,3 +720,36 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     torch.testing.assert_close(
         packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
     )
+
+
+class MyModelValidator(ModelValidationPlugin):
+    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
+        return True
+
+    def validate_model(
+        self, model_type: ModelType, model_path: str, model: Optional[str] = None
+    ) -> None:
+        raise BaseException("Model did not validate")
+
+
+def test_worker_adapter_manager_security_policy(dist_init, dummy_model_gate_up):
+    my_model_validator = MyModelValidator()
+    ModelValidationPluginRegistry.register_plugin("test", my_model_validator)
+
+    lora_config = LoRAConfig(
+        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
+    )
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
+
+    worker_adapter_manager = WorkerLoRAManager(
+        vllm_config, "cpu", EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
+    )
+    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
+
+    mapping = LoRAMapping([], [])
+    with pytest.raises(BaseException, match="Model did not validate"):
+        worker_adapter_manager.set_active_adapters(
+            [LoRARequest("1", 1, "/not/used")], mapping
+        )

vllm/lora/request.py (8 additions, 0 deletions)

@@ -6,6 +6,8 @@
 
 import msgspec
 
+from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry
+
 
 class LoRARequest(
     msgspec.Struct,
@@ -99,3 +101,9 @@ def __hash__(self) -> int:
         identified by their names across engines.
         """
         return hash(self.lora_name)
+
+    def validate(self) -> None:
+        """Validate the LoRA adapter given its path."""
+        ModelValidationPluginRegistry.validate_model(
+            ModelType.MODEL_TYPE_LORA, self.lora_path
+        )
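
Validation runs implicitly when the worker loads an adapter (see worker_manager.py below), but the new method can also be called directly. A minimal usage sketch; the adapter name, ID, and path are placeholders:

from vllm.lora.request import LoRARequest

# Raises whatever the registered plugin raises when validation fails;
# presumably a no-op when no registered plugin claims the adapter.
request = LoRARequest("my-adapter", 1, "/path/to/adapter")
request.validate()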

vllm/lora/worker_manager.py (1 addition, 0 deletions)

@@ -85,6 +85,7 @@ def create_lora_manager(
         return lora_manager.model
 
     def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
+        lora_request.validate()
         try:
             supported_lora_modules = self._adapter_manager.supported_lora_modules
             packed_modules_mapping = self._adapter_manager.packed_modules_mapping
