14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -1079,6 +1079,7 @@ def num_gpus_available():

temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_opt_unmodified_path = os.path.join(temp_dir, "dummy_opt_unmodified")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")

@@ -1101,6 +1102,19 @@ def dummy_opt_path():
return _dummy_opt_path


@pytest.fixture
def dummy_opt_unmodified_path():
json_path = os.path.join(_dummy_opt_unmodified_path, "config.json")
if not os.path.exists(_dummy_opt_unmodified_path):
snapshot_download(
repo_id="facebook/opt-125m",
local_dir=_dummy_opt_unmodified_path,
ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
)
assert os.path.exists(json_path)
return _dummy_opt_unmodified_path


@pytest.fixture
def dummy_llava_path():
json_path = os.path.join(_dummy_llava_path, "config.json")
39 changes: 39 additions & 0 deletions tests/lora/test_lora_manager.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Optional

import pytest
import torch
@@ -26,6 +27,11 @@
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
from vllm.platforms import current_platform
from vllm.validation.plugins import (
ModelType,
ModelValidationPlugin,
ModelValidationPluginRegistry,
)

from .utils import create_peft_lora

@@ -714,3 +720,36 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
torch.testing.assert_close(
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
)


class MyModelValidator(ModelValidationPlugin):
def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
return True

def validate_model(
self, model_type: ModelType, model_path: str, model: Optional[str] = None
) -> None:
raise BaseException("Model did not validate")


def test_worker_adapter_manager_security_policy(dist_init, dummy_model_gate_up):
my_model_validator = MyModelValidator()
ModelValidationPluginRegistry.register_plugin("test", my_model_validator)

lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)

model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

worker_adapter_manager = WorkerLoRAManager(
vllm_config, "cpu", EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

mapping = LoRAMapping([], [])
with pytest.raises(BaseException, match="Model did not validate"):
worker_adapter_manager.set_active_adapters(
[LoRARequest("1", 1, "/not/used")], mapping
)
75 changes: 75 additions & 0 deletions tests/model_executor/model_loader/test_model_validation.py
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import pytest
from torch import nn

from vllm.config import DeviceConfig, ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader import get_model_loader, register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
from vllm.validation.plugins import (
ModelType,
ModelValidationPlugin,
ModelValidationPluginRegistry,
)


@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig) -> None:
super().__init__(load_config)
self.download_type = None

def download_model(self, model_config: ModelConfig) -> None:
pass

def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
pass

def set_download_type(self, download_type: DownloadType) -> None:
"""Allow changing download_type"""
self.download_type = download_type

def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
return self.download_type


class MyModelValidator(ModelValidationPlugin):
def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
return True

def validate_model(
self, model_type: ModelType, model_path: str, model: Optional[str] = None
) -> None:
raise BaseException("Model did not validate")


def test_register_model_loader(dist_init):
load_config = LoadConfig(load_format="custom_load_format")
custom_model_loader = get_model_loader(load_config)
assert isinstance(custom_model_loader, CustomModelLoader)

my_model_validator = MyModelValidator()
ModelValidationPluginRegistry.register_plugin("test", my_model_validator)

vllm_config = VllmConfig(
model_config=ModelConfig(),
device_config=DeviceConfig("auto"),
load_config=LoadConfig(),
)
with pytest.raises(RuntimeError):
custom_model_loader.load_model(vllm_config, vllm_config.model_config)

# With a local download type, download_all_files() returns a folder,
# so validate_model() is invoked and the plugin raises.
custom_model_loader.set_download_type(DownloadType.LOCAL_FILE)

vllm_config = VllmConfig(
model_config=ModelConfig(),
device_config=DeviceConfig("cpu"),
load_config=LoadConfig(),
)
with pytest.raises(BaseException, match="Model did not validate"):
custom_model_loader.load_model(vllm_config, vllm_config.model_config)
49 changes: 49 additions & 0 deletions tests/v1/engine/test_engine_core_model_validation.py
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor
from vllm.validation.plugins import (
ModelType,
ModelValidationPlugin,
ModelValidationPluginRegistry,
)


class MyModelValidator(ModelValidationPlugin):
def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
return True

def validate_model(
self, model_type: ModelType, model_path: str, model: Optional[str] = None
) -> None:
raise BaseException("Model did not validate")


def test_engine_core_model_validation(
monkeypatch: pytest.MonkeyPatch, dummy_opt_unmodified_path
):
my_model_validator = MyModelValidator()
ModelValidationPluginRegistry.register_plugin("test", my_model_validator)

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=dummy_opt_unmodified_path)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1), pytest.raises(
    BaseException, match="Model did not validate"
):
EngineCore(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
5 changes: 4 additions & 1 deletion vllm/entrypoints/openai/serving_models.py
@@ -249,6 +249,7 @@ async def resolve_lora(self, lora_name: str) -> Union[LoRARequest, ErrorResponse
base_model_name = self.model_config.model
unique_id = self.lora_id_counter.inc(1)
found_adapter = False
reason = ""

# Try to resolve using available resolvers
for resolver in self.lora_resolvers:
@@ -275,13 +276,15 @@
resolver.__class__.__name__,
e,
)
reason = str(e)
continue

if found_adapter:
# An adapter was found, but all attempts to load it failed.
return create_error_response(
message=(
f"LoRA adapter '{lora_name}' was found but could not be loaded."
f"LoRA adapter '{lora_name}' was found "
f"but could not be loaded: {reason}"
),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST,
8 changes: 8 additions & 0 deletions vllm/lora/request.py
@@ -6,6 +6,8 @@

import msgspec

from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry


class LoRARequest(
msgspec.Struct,
@@ -99,3 +101,9 @@ def __hash__(self) -> int:
identified by their names across engines.
"""
return hash(self.lora_name)

def validate(self) -> None:
"""Validate the LoRA adapter given its path."""
ModelValidationPluginRegistry.validate_model(
ModelType.MODEL_TYPE_LORA, self.lora_path
)
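
For illustration only (not part of this diff): a minimal sketch of a plugin that the `validate()` hook above would invoke for LoRA adapters. The `adapter_config.json` check and the `"lora-check"` plugin name are hypothetical; the plugin interface and registry calls follow `vllm.validation.plugins` as exercised in the tests above.

```python
import os
from typing import Optional

from vllm.validation.plugins import (
    ModelType,
    ModelValidationPlugin,
    ModelValidationPluginRegistry,
)


class LoRAAdapterValidator(ModelValidationPlugin):
    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        # Only inspect LoRA adapters; full models are left to other plugins.
        return model_type == ModelType.MODEL_TYPE_LORA

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        # Hypothetical policy: reject adapters without an adapter_config.json.
        if not os.path.isfile(os.path.join(model_path, "adapter_config.json")):
            raise ValueError(f"LoRA adapter at {model_path} failed validation")


# Registering makes WorkerLoRAManager._load_adapter() fail fast for bad adapters.
ModelValidationPluginRegistry.register_plugin("lora-check", LoRAAdapterValidator())
```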
1 change: 1 addition & 0 deletions vllm/lora/worker_manager.py
@@ -85,6 +85,7 @@ def create_lora_manager(
return lora_manager.model

def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
lora_request.validate()
try:
supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
51 changes: 51 additions & 0 deletions vllm/model_executor/model_loader/base_loader.py
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from abc import ABC, abstractmethod
from typing import Optional

import huggingface_hub
import torch
import torch.nn as nn

@@ -13,10 +16,18 @@
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry

logger = init_logger(__name__)


class DownloadType(int, enum.Enum):
HUGGINGFACE_HUB = 1
LOCAL_FILE = 2
S3 = 3 # not currently supported
UNKNOWN = 4


class BaseModelLoader(ABC):
"""Base class for model loaders."""

@@ -34,6 +45,45 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
inplace weights loading for an already-initialized model"""
raise NotImplementedError

def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
"""Subclass must override this and return the download type it needs"""
return None

def download_all_files(
self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig
) -> Optional[str]:
"""Download all files. Ask the subclass for what type of download
it does; Huggingface is used so often, so download all files here."""
dt = self.get_download_type(model_config.model)
if dt == DownloadType.HUGGINGFACE_HUB:
return huggingface_hub.snapshot_download(
model_config.model,
allow_patterns=["*"],
cache_dir=self.load_config.download_dir,
revision=model_config.revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
elif dt == DownloadType.LOCAL_FILE:
return model_config.model
return None

def validate_model(
self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig
) -> None:
"""If needed, validate the model after downloading _all_ its files."""
if ModelValidationPluginRegistry.model_validation_needed(
ModelType.MODEL_TYPE_AI_MODEL, model_config.model
):
folder = self.download_all_files(model, model_config, load_config)
if folder is None:
raise RuntimeError(
"Model validation could not be done due to "
"an unsupported download method."
)
ModelValidationPluginRegistry.validate_model(
ModelType.MODEL_TYPE_AI_MODEL, folder, model_config.model
)

def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
@@ -51,6 +101,7 @@ def load_model(
)

logger.debug("Loading weights on %s ...", load_device)
self.validate_model(model, model_config, vllm_config.load_config)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
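For illustration only (not part of this diff): a sketch of a whole-model plugin that the new `validate_model()` path in `BaseModelLoader` would call with the folder returned by `download_all_files()`. The "no pickled weights" policy and the `"model-check"` plugin name are hypothetical; only the plugin interfaces are taken from `vllm.validation.plugins`.

```python
import os
from typing import Optional

from vllm.validation.plugins import (
    ModelType,
    ModelValidationPlugin,
    ModelValidationPluginRegistry,
)


class NoPickleModelValidator(ModelValidationPlugin):
    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        # Only run for full models; LoRA adapters are handled elsewhere.
        return model_type == ModelType.MODEL_TYPE_AI_MODEL

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        # model_path is the local folder with all downloaded files; model is the
        # original model name from the config. Hypothetical policy: reject
        # pickle-based weight files.
        for name in os.listdir(model_path):
            if name.endswith((".bin", ".pt", ".pickle")):
                raise RuntimeError(
                    f"{model or model_path} contains a pickled file: {name}"
                )


ModelValidationPluginRegistry.register_plugin("model-check", NoPickleModelValidator())
```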
7 changes: 6 additions & 1 deletion vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -31,7 +31,7 @@
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
@@ -820,3 +820,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

def get_download_type(self, model_name_or_path: str) -> DownloadType:
if os.path.isdir(model_name_or_path):
return DownloadType.LOCAL_FILE
return DownloadType.HUGGINGFACE_HUB
7 changes: 6 additions & 1 deletion vllm/model_executor/model_loader/default_loader.py
@@ -14,7 +14,7 @@
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
@@ -319,3 +319,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)

def get_download_type(self, model_name_or_path: str) -> DownloadType:
if os.path.isdir(model_name_or_path):
return DownloadType.LOCAL_FILE
return DownloadType.HUGGINGFACE_HUB