diff --git a/tests/conftest.py b/tests/conftest.py
index 4713e1238596..a1304a50bc58 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1079,6 +1079,7 @@ def num_gpus_available():
 
 
 temp_dir = tempfile.gettempdir()
 _dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
+_dummy_opt_unmodified_path = os.path.join(temp_dir, "dummy_opt_unmodified")
 _dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
 _dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
@@ -1101,6 +1102,19 @@ def dummy_opt_path():
     return _dummy_opt_path
 
 
+@pytest.fixture
+def dummy_opt_unmodified_path():
+    json_path = os.path.join(_dummy_opt_unmodified_path, "config.json")
+    if not os.path.exists(_dummy_opt_unmodified_path):
+        snapshot_download(
+            repo_id="facebook/opt-125m",
+            local_dir=_dummy_opt_unmodified_path,
+            ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
+        )
+    assert os.path.exists(json_path)
+    return _dummy_opt_unmodified_path
+
+
 @pytest.fixture
 def dummy_llava_path():
     json_path = os.path.join(_dummy_llava_path, "config.json")
diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index e7816031142e..79a16cd9480c 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
+from typing import Optional
 
 import pytest
 import torch
@@ -26,6 +27,11 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
 from vllm.platforms import current_platform
+from vllm.validation.plugins import (
+    ModelType,
+    ModelValidationPlugin,
+    ModelValidationPluginRegistry,
+)
 
 from .utils import create_peft_lora
 
@@ -714,3 +720,36 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     torch.testing.assert_close(
         packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
     )
+
+
+class MyModelValidator(ModelValidationPlugin):
+    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
+        return True
+
+    def validate_model(
+        self, model_type: ModelType, model_path: str, model: Optional[str] = None
+    ) -> None:
+        raise BaseException("Model did not validate")
+
+
+def test_worker_adapter_manager_security_policy(dist_init, dummy_model_gate_up):
+    my_model_validator = MyModelValidator()
+    ModelValidationPluginRegistry.register_plugin("test", my_model_validator)
+
+    lora_config = LoRAConfig(
+        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
+    )
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
+
+    worker_adapter_manager = WorkerLoRAManager(
+        vllm_config, "cpu", EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
+    )
+    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
+
+    mapping = LoRAMapping([], [])
+    with pytest.raises(BaseException, match="Model did not validate"):
+        worker_adapter_manager.set_active_adapters(
+            [LoRARequest("1", 1, "/not/used")], mapping
+        )
diff --git a/tests/model_executor/model_loader/test_model_validation.py b/tests/model_executor/model_loader/test_model_validation.py
new file mode 100644
index 000000000000..40a4e0756a24
--- /dev/null
+++ b/tests/model_executor/model_loader/test_model_validation.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import pytest
+from torch import nn
+
+from vllm.config import DeviceConfig, ModelConfig, VllmConfig
+from vllm.config.load import LoadConfig
+from vllm.model_executor.model_loader import get_model_loader, register_model_loader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
+from vllm.validation.plugins import (
+    ModelType,
+    ModelValidationPlugin,
+    ModelValidationPluginRegistry,
+)
+
+
+@register_model_loader("custom_load_format")
+class CustomModelLoader(BaseModelLoader):
+    def __init__(self, load_config: LoadConfig) -> None:
+        super().__init__(load_config)
+        self.download_type = None
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
+        pass
+
+    def set_download_type(self, download_type: DownloadType) -> None:
+        """Allow changing download_type"""
+        self.download_type = download_type
+
+    def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
+        return self.download_type
+
+
+class MyModelValidator(ModelValidationPlugin):
+    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
+        return True
+
+    def validate_model(
+        self, model_type: ModelType, model_path: str, model: Optional[str] = None
+    ) -> None:
+        raise BaseException("Model did not validate")
+
+
+def test_register_model_loader(dist_init):
+    load_config = LoadConfig(load_format="custom_load_format")
+    custom_model_loader = get_model_loader(load_config)
+    assert isinstance(custom_model_loader, CustomModelLoader)
+
+    my_model_validator = MyModelValidator()
+    ModelValidationPluginRegistry.register_plugin("test", my_model_validator)
+
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(),
+        device_config=DeviceConfig("auto"),
+        load_config=LoadConfig(),
+    )
+    with pytest.raises(RuntimeError):
+        custom_model_loader.load_model(vllm_config, vllm_config.model_config)
+
+    # ensure validate_model() gets called by advertising a supported download type
+    custom_model_loader.set_download_type(DownloadType.LOCAL_FILE)
+
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(),
+        device_config=DeviceConfig("cpu"),
+        load_config=LoadConfig(),
+    )
+    with pytest.raises(BaseException, match="Model did not validate"):
+        custom_model_loader.load_model(vllm_config, vllm_config.model_config)
diff --git a/tests/v1/engine/test_engine_core_model_validation.py b/tests/v1/engine/test_engine_core_model_validation.py
new file mode 100644
index 000000000000..8ce8cf48fd3e
--- /dev/null
+++ b/tests/v1/engine/test_engine_core_model_validation.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import pytest
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.utils import set_default_torch_num_threads
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.executor.abstract import Executor
+from vllm.validation.plugins import (
+    ModelType,
+    ModelValidationPlugin,
+    ModelValidationPluginRegistry,
+)
+
+
+class MyModelValidator(ModelValidationPlugin):
+    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
+        return True
+
+    def validate_model(
+        self, model_type: ModelType, model_path: str, model: Optional[str] = None
+    ) -> None:
+        raise BaseException("Model did not validate")
+
+
+def test_engine_core_model_validation(
+    monkeypatch: pytest.MonkeyPatch, dummy_opt_unmodified_path
+):
+    my_model_validator = MyModelValidator()
+    ModelValidationPluginRegistry.register_plugin("test", my_model_validator)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine_args = EngineArgs(model=dummy_opt_unmodified_path)
+        vllm_config = engine_args.create_engine_config()
+        executor_class = Executor.get_class(vllm_config)
+
+        with set_default_torch_num_threads(1), pytest.raises(
+            BaseException, match="Model did not validate"
+        ):
+            EngineCore(
+                vllm_config=vllm_config,
+                executor_class=executor_class,
+                log_stats=False,
+            )
diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py
index d2a58a487a76..8ba615120249 100644
--- a/vllm/entrypoints/openai/serving_models.py
+++ b/vllm/entrypoints/openai/serving_models.py
@@ -249,6 +249,7 @@ async def resolve_lora(self, lora_name: str) -> Union[LoRARequest, ErrorResponse
         base_model_name = self.model_config.model
         unique_id = self.lora_id_counter.inc(1)
         found_adapter = False
+        reason = ""
 
         # Try to resolve using available resolvers
         for resolver in self.lora_resolvers:
@@ -275,13 +276,15 @@ async def resolve_lora(self, lora_name: str) -> Union[LoRARequest, ErrorResponse
                     resolver.__class__.__name__,
                     e,
                 )
+                reason = str(e)
                 continue
 
         if found_adapter:
             # An adapter was found, but all attempts to load it failed.
             return create_error_response(
                 message=(
-                    f"LoRA adapter '{lora_name}' was found but could not be loaded."
+                    f"LoRA adapter '{lora_name}' was found "
+                    f"but could not be loaded: {reason}"
                 ),
                 err_type="BadRequestError",
                 status_code=HTTPStatus.BAD_REQUEST,
diff --git a/vllm/lora/request.py b/vllm/lora/request.py
index 650e060a5804..8e1debf0c22a 100644
--- a/vllm/lora/request.py
+++ b/vllm/lora/request.py
@@ -6,6 +6,8 @@
 
 import msgspec
 
+from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry
+
 
 class LoRARequest(
     msgspec.Struct,
@@ -99,3 +101,9 @@ def __hash__(self) -> int:
         identified by their names across engines.
""" return hash(self.lora_name) + + def validate(self) -> None: + """Validate the LoRA adapter given its path.""" + ModelValidationPluginRegistry.validate_model( + ModelType.MODEL_TYPE_LORA, self.lora_path + ) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3ca819fb732c..42b2681aee7a 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -85,6 +85,7 @@ def create_lora_manager( return lora_manager.model def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: + lora_request.validate() try: supported_lora_modules = self._adapter_manager.supported_lora_modules packed_modules_mapping = self._adapter_manager.packed_modules_mapping diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 6106a1ab8a85..e59df96e1bf2 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum from abc import ABC, abstractmethod +from typing import Optional +import huggingface_hub import torch import torch.nn as nn @@ -13,10 +16,18 @@ process_weights_after_loading, set_default_torch_dtype, ) +from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry logger = init_logger(__name__) +class DownloadType(int, enum.Enum): + HUGGINGFACE_HUB = 1 + LOCAL_FILE = 2 + S3 = 3 # not currently supported + UNKNOWN = 4 + + class BaseModelLoader(ABC): """Base class for model loaders.""" @@ -34,6 +45,45 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: inplace weights loading for an already-initialized model""" raise NotImplementedError + def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]: + """Subclass must override this and return the download type it needs""" + return None + + def download_all_files( + self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[str]: + """Download all files. Ask the subclass for what type of download + it does; Huggingface is used so often, so download all files here.""" + dt = self.get_download_type(model_config.model) + if dt == DownloadType.HUGGINGFACE_HUB: + return huggingface_hub.snapshot_download( + model_config.model, + allow_patterns=["*"], + cache_dir=self.load_config.download_dir, + revision=model_config.revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + elif dt == DownloadType.LOCAL_FILE: + return model_config.model + return None + + def validate_model( + self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig + ) -> None: + """If needed, validate the model after downloading _all_ its files.""" + if ModelValidationPluginRegistry.model_validation_needed( + ModelType.MODEL_TYPE_AI_MODEL, model_config.model + ): + folder = self.download_all_files(model, model_config, load_config) + if folder is None: + raise RuntimeError( + "Model validation could not be done due to " + "an unsupported download method." 
+                )
+            ModelValidationPluginRegistry.validate_model(
+                ModelType.MODEL_TYPE_AI_MODEL, folder, model_config.model
+            )
+
     def load_model(
         self, vllm_config: VllmConfig, model_config: ModelConfig
     ) -> nn.Module:
@@ -51,6 +101,7 @@ def load_model(
             )
 
             logger.debug("Loading weights on %s ...", load_device)
+            self.validate_model(model, model_config, vllm_config.load_config)
             # Quantization does not happen in `load_weights` but after it
             self.load_weights(model, model_config)
         process_weights_after_loading(model, model_config, target_device)
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 8c1ff0300b24..40a8f7403845 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -31,7 +31,7 @@
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
@@ -820,3 +820,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 206b8244569f..3cfd2458633d 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -14,7 +14,7 @@
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
@@ -319,3 +319,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 "Following weights were not initialized from "
                 f"checkpoint: {weights_not_loaded}"
             )
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 93dc754a571c..59c02c6c58b9 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -11,7 +11,7 @@
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.config.load import LoadConfig
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
@@ -40,15 +40,13 @@ def __init__(self, load_config: LoadConfig):
         )
 
     def _prepare_weights(self, model_name_or_path: str):
-        if os.path.isfile(model_name_or_path):
+        download_type = self.get_download_type(model_name_or_path)
+
+        if download_type == DownloadType.LOCAL_FILE:
             return model_name_or_path
-        # for raw HTTPS link
-        if model_name_or_path.startswith(
-            ("http://", "https://")
-        ) and model_name_or_path.endswith(".gguf"):
-            return hf_hub_download(url=model_name_or_path)
-        # repo id/filename.gguf
-        if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
+        elif download_type == DownloadType.HUGGINGFACE_HUB:
+            if model_name_or_path.startswith(("http://", "https://")):
+                return hf_hub_download(url=model_name_or_path)
             repo_id, filename = model_name_or_path.rsplit("/", 1)
             return hf_hub_download(repo_id=repo_id, filename=filename)
         else:
@@ -170,3 +168,13 @@ def load_model(
             process_weights_after_loading(model, model_config, target_device)
 
         return model
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isfile(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        if model_name_or_path.endswith(".gguf") and (
+            model_name_or_path.startswith(("http://", "https://"))
+            or "/" in model_name_or_path
+        ):
+            return DownloadType.HUGGINGFACE_HUB
+        return DownloadType.UNKNOWN
diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py
index 50a92edd1162..9f2f773e3078 100644
--- a/vllm/model_executor/model_loader/runai_streamer_loader.py
+++ b/vllm/model_executor/model_loader/runai_streamer_loader.py
@@ -11,7 +11,7 @@
 
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
@@ -109,3 +109,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         model.load_weights(
             self._get_weights_iterator(model_weights, model_config.revision)
         )
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB
diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py
index d50a1a8f9dbf..6ea73fc6ac7f 100644
--- a/vllm/model_executor/model_loader/sharded_state_loader.py
+++ b/vllm/model_executor/model_loader/sharded_state_loader.py
@@ -13,7 +13,7 @@
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf,
     runai_safetensors_weights_iterator,
@@ -204,3 +204,8 @@ def save_model(
                 state_dict_part,
                 os.path.join(path, filename),
             )
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 4826d7c589a7..013cb8d8c81e 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -65,6 +65,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -97,6 +98,11 @@ def __init__(
             vllm_config,
         )
 
+        if os.path.isdir(vllm_config.model_config.model):
+            ModelValidationPluginRegistry.validate_model(
+                ModelType.MODEL_TYPE_AI_MODEL, vllm_config.model_config.model
+            )
+
         self.log_stats = log_stats
 
         # Setup Model.
@@ -115,6 +121,16 @@ def __init__(
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
         self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
 
+        if ModelValidationPluginRegistry.model_validation_needed(
+            ModelType.MODEL_TYPE_AI_MODEL, vllm_config.model_config.model
+        ):
+            raise Exception(
+                "Model validation was requested for "
+                f"{vllm_config.model_config.model} but was not "
+                "done since a code path was taken that is not yet "
+                "instrumented for model validation."
+            )
+
         self.structured_output_manager = StructuredOutputManager(vllm_config)
 
         # Setup scheduler.
@@ -852,8 +868,11 @@ def _handle_client_request(
                 output.result = UtilityResult(result)
             except BaseException as e:
                 logger.exception("Invocation of %s method failed", method_name)
+                message = str(e)
+                if e.__cause__:
+                    message += f" caused by {e.__cause__}"
                 output.failure_message = (
-                    f"Call to {method_name} method failed: {str(e)}"
+                    f"Call to {method_name} method failed: {message}"
                 )
                 self.output_queue.put_nowait(
                     (client_idx, EngineCoreOutputs(utility_output=output))
diff --git a/vllm/validation/__init__.py b/vllm/validation/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/validation/plugins.py b/vllm/validation/plugins.py
new file mode 100644
index 000000000000..c96df7c78a4e
--- /dev/null
+++ b/vllm/validation/plugins.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import enum
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Optional
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ModelType(int, enum.Enum):
+    MODEL_TYPE_AI_MODEL = 1
+    MODEL_TYPE_LORA = 2
+
+
+class ModelValidationPlugin(ABC):
+    """Base class for all model validation plugins."""
+
+    @abstractmethod
+    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
+        """Return True if this plugin still needs to validate the model
+        at the given model_path."""
+        return False
+
+    @abstractmethod
+    def validate_model(
+        self, model_type: ModelType, model_path: str, model: Optional[str] = None
+    ) -> None:
+        """Validate the model at the given model_path."""
+        pass
+
+
+@dataclass
+class _ModelValidationPluginRegistry:
+    plugins: dict[str, ModelValidationPlugin] = field(default_factory=dict)
+
+    def register_plugin(self, plugin_name: str, plugin: ModelValidationPlugin):
+        """Register a model validation plugin."""
+        if plugin_name in self.plugins:
+            logger.warning(
+                "Model validation plugin %s is already registered, and will be "
+                "overwritten by the new plugin %s.",
+                plugin_name,
+                plugin,
+            )
+
+        self.plugins[plugin_name] = plugin
+
+    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
+        """Check whether model validation was requested but has not been done yet.
+        Returns False if no validation was requested or it has already been done.
+        Returns True if validation was requested but has not been done yet."""
+        for plugin in self.plugins.values():
+            if plugin.model_validation_needed(model_type, model_path):
+                return True
+        return False
+
+    def validate_model(
+        self, model_type: ModelType, model_path: str, model: Optional[str] = None
+    ) -> None:
+        """Have all plugins validate the model at the given path. Any plugin
+        that cannot validate it will raise an exception."""
+        plugins = self.plugins.values()
+        if plugins:
+            for plugin in plugins:
+                plugin.validate_model(model_type, model_path, model)
+            logger.info("Successfully validated %s", model_path)
+
+
+ModelValidationPluginRegistry = _ModelValidationPluginRegistry()
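
Reviewer note: the snippet below is a minimal sketch of how an out-of-tree package might hook into the new registry. The SafeTensorsOnlyValidator name, its file-extension policy, and the registration call are illustrative assumptions and not part of this patch; only the ModelType, ModelValidationPlugin, and ModelValidationPluginRegistry names come from the change itself.

# Illustrative sketch only: names and the validation policy below are assumptions.
import os
from typing import Optional

from vllm.validation.plugins import (
    ModelType,
    ModelValidationPlugin,
    ModelValidationPluginRegistry,
)


class SafeTensorsOnlyValidator(ModelValidationPlugin):
    """Hypothetical plugin that rejects snapshots containing pickle-based weights."""

    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        # Only full models are checked in this sketch; LoRA adapters are skipped.
        return model_type == ModelType.MODEL_TYPE_AI_MODEL

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        # model_path is the local folder with all files downloaded
        # (see BaseModelLoader.validate_model); refuse anything pickle-based.
        for name in os.listdir(model_path):
            if name.endswith((".bin", ".pt", ".pth")):
                raise ValueError(
                    f"{model_path} contains a pickle-based checkpoint: {name}"
                )


# Registration would typically happen from the package's own plugin entry point.
ModelValidationPluginRegistry.register_plugin(
    "safetensors_only", SafeTensorsOnlyValidator()
)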