Skip to content

Commit 5a9007b

Browse files
committed
Add model validation plugin and registry and validate local models
Add a model validation plugin registry where classes implementing the ModelValidationPlugin interface can be registered. Enable the validating on local models that have already been downloaded by the user. Add a test case with an already downloaded model whose config.json is unmodified so that a ModelConfig can be created from it. Signed-off-by: Stefan Berger <[email protected]>
1 parent b8f603c commit 5a9007b

File tree

5 files changed

+152
-0
lines changed

5 files changed

+152
-0
lines changed

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,7 @@ def num_gpus_available():
10621062

10631063
temp_dir = tempfile.gettempdir()
10641064
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
1065+
_dummy_opt_unmodified_path = os.path.join(temp_dir, "dummy_opt_unmodified")
10651066
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
10661067
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
10671068

@@ -1084,6 +1085,19 @@ def dummy_opt_path():
10841085
return _dummy_opt_path
10851086

10861087

1088+
@pytest.fixture
1089+
def dummy_opt_unmodified_path():
1090+
json_path = os.path.join(_dummy_opt_unmodified_path, "config.json")
1091+
if not os.path.exists(_dummy_opt_unmodified_path):
1092+
snapshot_download(
1093+
repo_id="facebook/opt-125m",
1094+
local_dir=_dummy_opt_unmodified_path,
1095+
ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"],
1096+
)
1097+
assert os.path.exists(json_path)
1098+
return _dummy_opt_unmodified_path
1099+
1100+
10871101
@pytest.fixture
10881102
def dummy_llava_path():
10891103
json_path = os.path.join(_dummy_llava_path, "config.json")
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from typing import Optional
5+
6+
import pytest
7+
8+
from vllm.engine.arg_utils import EngineArgs
9+
from vllm.utils import set_default_torch_num_threads
10+
from vllm.v1.engine.core import EngineCore
11+
from vllm.v1.executor.abstract import Executor
12+
from vllm.validation.plugins import (
13+
ModelType,
14+
ModelValidationPlugin,
15+
ModelValidationPluginRegistry,
16+
)
17+
18+
19+
class MyModelValidator(ModelValidationPlugin):
20+
def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
21+
return True
22+
23+
def validate_model(
24+
self, model_type: ModelType, model_path: str, model: Optional[str] = None
25+
) -> None:
26+
raise BaseException("Model did not validate")
27+
28+
29+
def test_engine_core_model_validation(
30+
monkeypatch: pytest.MonkeyPatch, dummy_opt_unmodified_path
31+
):
32+
my_model_validator = MyModelValidator()
33+
ModelValidationPluginRegistry.register_plugin("test", my_model_validator)
34+
35+
with monkeypatch.context() as m:
36+
m.setenv("VLLM_USE_V1", "1")
37+
38+
engine_args = EngineArgs(model=dummy_opt_unmodified_path)
39+
vllm_config = engine_args.create_engine_config()
40+
executor_class = Executor.get_class(vllm_config)
41+
42+
with set_default_torch_num_threads(1) and pytest.raises(
43+
BaseException, match="Model did not validate"
44+
):
45+
EngineCore(
46+
vllm_config=vllm_config,
47+
executor_class=executor_class,
48+
log_stats=False,
49+
)

vllm/v1/engine/core.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from vllm.v1.request import Request, RequestStatus
6666
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
6767
from vllm.v1.structured_output import StructuredOutputManager
68+
from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry
6869
from vllm.version import __version__ as VLLM_VERSION
6970

7071
logger = init_logger(__name__)
@@ -97,6 +98,11 @@ def __init__(
9798
vllm_config,
9899
)
99100

101+
if os.path.isdir(vllm_config.model_config.model):
102+
ModelValidationPluginRegistry.validate_model(
103+
ModelType.MODEL_TYPE_AI_MODEL, vllm_config.model_config.model
104+
)
105+
100106
self.log_stats = log_stats
101107

102108
# Setup Model.
@@ -115,6 +121,16 @@ def __init__(
115121
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
116122
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
117123

124+
if ModelValidationPluginRegistry.model_validation_needed(
125+
ModelType.MODEL_TYPE_AI_MODEL, vllm_config.model_config.model
126+
):
127+
raise Exception(
128+
"Model validation was requested for "
129+
f"{vllm_config.model_config.model} but was not "
130+
"done since a code path was taken that is not yet "
131+
"instrumented for model validation."
132+
)
133+
118134
self.structured_output_manager = StructuredOutputManager(vllm_config)
119135

120136
# Setup scheduler.

vllm/validation/__init__.py

Whitespace-only changes.

vllm/validation/plugins.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import enum
5+
from abc import ABC, abstractmethod
6+
from dataclasses import dataclass, field
7+
from typing import Optional
8+
9+
from vllm.logger import init_logger
10+
11+
logger = init_logger(__name__)
12+
13+
14+
class ModelType(int, enum.Enum):
15+
MODEL_TYPE_AI_MODEL = 1
16+
MODEL_TYPE_LORA = 2
17+
18+
19+
class ModelValidationPlugin(ABC):
20+
"""Base class for all model validation plugins"""
21+
22+
@abstractmethod
23+
def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
24+
"""Have the plugin check whether it already validated the model
25+
at the given model_path."""
26+
return False
27+
28+
@abstractmethod
29+
def validate_model(
30+
self, model_type: ModelType, model_path: str, model: Optional[str] = None
31+
) -> None:
32+
"""Validate the model at the given model_path."""
33+
pass
34+
35+
36+
@dataclass
37+
class _ModelValidationPluginRegistry:
38+
plugins: dict[str, ModelValidationPlugin] = field(default_factory=dict)
39+
40+
def register_plugin(self, plugin_name: str, plugin: ModelValidationPlugin):
41+
"""Register a security plugin."""
42+
if plugin_name in self.plugins:
43+
logger.warning(
44+
"Model validation plugin %s is already registered, and will be "
45+
"overwritten by the new plugin %s.",
46+
plugin_name,
47+
plugin,
48+
)
49+
50+
self.plugins[plugin_name] = plugin
51+
52+
def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
53+
"""Check whether model validation was requested but was not done, yet.
54+
Returns False in case no model validation was requested or it is already
55+
done. Returns True if model validation was request but not done yet."""
56+
for plugin in self.plugins.values():
57+
if plugin.model_validation_needed(model_type, model_path):
58+
return True
59+
return False
60+
61+
def validate_model(
62+
self, model_type: ModelType, model_path: str, model: Optional[str] = None
63+
) -> None:
64+
"""Have all plugins validate the model at the given path. Any plugin
65+
that cannot validate it will throw an exception."""
66+
plugins = self.plugins.values()
67+
if plugins:
68+
for plugin in plugins:
69+
plugin.validate_model(model_type, model_path, model)
70+
logger.info("Successfully validated %s", model_path)
71+
72+
73+
ModelValidationPluginRegistry = _ModelValidationPluginRegistry()

0 commit comments

Comments
 (0)