Commit e73621f

Enable model validation on downloaded models

Implement a 'validate_model' method in BaseModelLoader that first checks whether any plugin requests validation of the given model and, if so, downloads all of the model's files, including the signature. To do this, query the BaseModelLoader subclass for its download type. Support validation of local models and of models downloaded from the Hugging Face Hub. Add a test case.

Signed-off-by: Stefan Berger <[email protected]>
1 parent 6d3aae5 commit e73621f
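
For orientation, the sketch below shows what a validation plugin can look like from the user side. It is a rough sketch that relies only on the ModelValidationPlugin interface and the ModelValidationPluginRegistry.register_plugin() call exercised by the new test further down; the signature-file check and the "model.sig" file name are hypothetical examples, not something this commit defines.

# Hedged sketch of a validation plugin; the interface mirrors MyModelValidator
# in the new test below. The "model.sig" file name is a hypothetical placeholder.
import os
from typing import Optional

from vllm.validation.plugins import (
    ModelType,
    ModelValidationPlugin,
    ModelValidationPluginRegistry,
)


class SignatureFileValidator(ModelValidationPlugin):
    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        # Request validation for AI models; returning False skips validation.
        return model_type == ModelType.MODEL_TYPE_AI_MODEL

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        # model_path is the fully downloaded folder (or local path) passed in by
        # BaseModelLoader.validate_model(); raising here aborts model loading.
        if not os.path.exists(os.path.join(model_path, "model.sig")):
            raise RuntimeError(f"No signature found for {model or model_path}")


ModelValidationPluginRegistry.register_plugin("signature-check", SignatureFileValidator())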

File tree

7 files changed: +167 -13 lines changed

New test file

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import pytest
from torch import nn

from vllm.config import DeviceConfig, ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader import get_model_loader, register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
from vllm.validation.plugins import (
    ModelType,
    ModelValidationPlugin,
    ModelValidationPluginRegistry,
)


@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):
    def __init__(self, load_config: LoadConfig) -> None:
        super().__init__(load_config)
        self.download_type = None

    def download_model(self, model_config: ModelConfig) -> None:
        pass

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        pass

    def set_download_type(self, download_type: DownloadType) -> None:
        """Allow changing download_type"""
        self.download_type = download_type

    def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
        return self.download_type


class MyModelValidator(ModelValidationPlugin):
    def model_validation_needed(self, model_type: ModelType, model_path: str) -> bool:
        return True

    def validate_model(
        self, model_type: ModelType, model_path: str, model: Optional[str] = None
    ) -> None:
        raise BaseException("Model did not validate")


def test_register_model_loader(dist_init):
    load_config = LoadConfig(load_format="custom_load_format")
    custom_model_loader = get_model_loader(load_config)
    assert isinstance(custom_model_loader, CustomModelLoader)

    my_model_validator = MyModelValidator()
    ModelValidationPluginRegistry.register_plugin("test", my_model_validator)

    vllm_config = VllmConfig(
        model_config=ModelConfig(),
        device_config=DeviceConfig("auto"),
        load_config=LoadConfig(),
    )
    with pytest.raises(RuntimeError):
        custom_model_loader.load_model(vllm_config, vllm_config.model_config)

    # have validate_model() called
    custom_model_loader.set_download_type(DownloadType.LOCAL_FILE)

    vllm_config = VllmConfig(
        model_config=ModelConfig(),
        device_config=DeviceConfig("cpu"),
        load_config=LoadConfig(),
    )
    with pytest.raises(BaseException, match="Model did not validate"):
        custom_model_loader.load_model(vllm_config, vllm_config.model_config)

vllm/model_executor/model_loader/base_loader.py

Lines changed: 51 additions & 0 deletions
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import enum
 from abc import ABC, abstractmethod
+from typing import Optional

+import huggingface_hub
 import torch
 import torch.nn as nn

@@ -13,10 +16,18 @@
     process_weights_after_loading,
     set_default_torch_dtype,
 )
+from vllm.validation.plugins import ModelType, ModelValidationPluginRegistry

 logger = init_logger(__name__)


+class DownloadType(int, enum.Enum):
+    HUGGINGFACE_HUB = 1
+    LOCAL_FILE = 2
+    S3 = 3  # not currently supported
+    UNKNOWN = 4
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""

@@ -34,6 +45,45 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         inplace weights loading for an already-initialized model"""
         raise NotImplementedError

+    def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
+        """Subclass must override this and return the download type it needs"""
+        return None
+
+    def download_all_files(
+        self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig
+    ) -> Optional[str]:
+        """Download all files. Ask the subclass for what type of download
+        it does; Huggingface is used so often, so download all files here."""
+        dt = self.get_download_type(model_config.model)
+        if dt == DownloadType.HUGGINGFACE_HUB:
+            return huggingface_hub.snapshot_download(
+                model_config.model,
+                allow_patterns=["*"],
+                cache_dir=self.load_config.download_dir,
+                revision=model_config.revision,
+                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            )
+        elif dt == DownloadType.LOCAL_FILE:
+            return model_config.model
+        return None
+
+    def validate_model(
+        self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig
+    ) -> None:
+        """If needed, validate the model after downloading _all_ its files."""
+        if ModelValidationPluginRegistry.model_validation_needed(
+            ModelType.MODEL_TYPE_AI_MODEL, model_config.model
+        ):
+            folder = self.download_all_files(model, model_config, load_config)
+            if folder is None:
+                raise RuntimeError(
+                    "Model validation could not be done due to "
+                    "an unsupported download method."
+                )
+            ModelValidationPluginRegistry.validate_model(
+                ModelType.MODEL_TYPE_AI_MODEL, folder, model_config.model
+            )
+
     def load_model(
         self, vllm_config: VllmConfig, model_config: ModelConfig
     ) -> nn.Module:
@@ -51,6 +101,7 @@ def load_model(
                 )

             logger.debug("Loading weights on %s ...", load_device)
+            self.validate_model(model, model_config, vllm_config.load_config)
             # Quantization does not happen in `load_weights` but after it
             self.load_weights(model, model_config)
             process_weights_after_loading(model, model_config, target_device)
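
Taken together, these hooks mean a loader only has to report its download type for validation to run right before load_weights(). A minimal out-of-tree sketch, mirroring the CustomModelLoader used in the new test (the "my_format" load format name is made up):

# Minimal sketch of a loader opting into validation; mirrors the test's
# CustomModelLoader. "my_format" is a hypothetical load format name.
from typing import Optional

from torch import nn

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType


@register_model_loader("my_format")
class MyLocalLoader(BaseModelLoader):
    def download_model(self, model_config: ModelConfig) -> None:
        pass  # nothing to download for local paths

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        pass  # weight loading elided in this sketch

    def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
        # Reporting LOCAL_FILE makes download_all_files() return the path as-is,
        # so registered validation plugins receive it directly.
        return DownloadType.LOCAL_FILE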

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 6 additions & 1 deletion
@@ -31,7 +31,7 @@
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
@@ -820,3 +820,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:

     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB

vllm/model_executor/model_loader/default_loader.py

Lines changed: 6 additions & 1 deletion
@@ -14,7 +14,7 @@
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
@@ -319,3 +319,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}"
                 )
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB

vllm/model_executor/model_loader/gguf_loader.py

Lines changed: 17 additions & 9 deletions
@@ -11,7 +11,7 @@

 from vllm.config import ModelConfig, VllmConfig
 from vllm.config.load import LoadConfig
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.utils import (
     initialize_model,
     process_weights_after_loading,
@@ -40,15 +40,13 @@ def __init__(self, load_config: LoadConfig):
             )

     def _prepare_weights(self, model_name_or_path: str):
-        if os.path.isfile(model_name_or_path):
+        download_type = self.get_download_type(model_name_or_path)
+
+        if download_type == DownloadType.LOCAL_FILE:
             return model_name_or_path
-        # for raw HTTPS link
-        if model_name_or_path.startswith(
-            ("http://", "https://")
-        ) and model_name_or_path.endswith(".gguf"):
-            return hf_hub_download(url=model_name_or_path)
-        # repo id/filename.gguf
-        if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
+        elif download_type == DownloadType.HUGGINGFACE_HUB:
+            if model_name_or_path.startswith(("http://", "https://")):
+                return hf_hub_download(url=model_name_or_path)
             repo_id, filename = model_name_or_path.rsplit("/", 1)
             return hf_hub_download(repo_id=repo_id, filename=filename)
         else:
@@ -170,3 +168,13 @@ def load_model(

         process_weights_after_loading(model, model_config, target_device)
         return model
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isfile(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        if model_name_or_path.endswith(".gguf") and (
+            model_name_or_path.startswith(("http://", "https://"))
+            or "/" in model_name_or_path
+        ):
+            return DownloadType.HUGGINGFACE_HUB
+        return DownloadType.UNKNOWN
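
As a quick illustration of the new classification (GGUFModelLoader is assumed to be the loader class defined in this module, and the paths are made-up examples; the expected results follow the get_download_type() logic above):

# Illustrative sketch only; the loader class name and LoadConfig usage are
# assumptions, and the expected values follow the diff above.
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import DownloadType
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader

loader = GGUFModelLoader(LoadConfig(load_format="gguf"))

loader.get_download_type("weights/model.gguf")  # LOCAL_FILE if the file exists,
                                                # otherwise HUGGINGFACE_HUB (repo/filename.gguf)
loader.get_download_type("https://example.com/m.gguf")  # HUGGINGFACE_HUB (raw .gguf link)
loader.get_download_type("somename")  # UNKNOWN; validation, if requested,
                                      # then fails with RuntimeError in validate_model()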

vllm/model_executor/model_loader/runai_streamer_loader.py

Lines changed: 6 additions & 1 deletion
@@ -11,7 +11,7 @@

 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
@@ -109,3 +109,8 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         model.load_weights(
             self._get_weights_iterator(model_weights, model_config.revision)
         )
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB

vllm/model_executor/model_loader/sharded_state_loader.py

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,7 @@
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader, DownloadType
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf,
     runai_safetensors_weights_iterator,
@@ -204,3 +204,8 @@ def save_model(
                     state_dict_part,
                     os.path.join(path, filename),
                 )
+
+    def get_download_type(self, model_name_or_path: str) -> DownloadType:
+        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
+            return DownloadType.LOCAL_FILE
+        return DownloadType.HUGGINGFACE_HUB
