1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import enum
3
4
from abc import ABC , abstractmethod
5
+ from typing import Optional
4
6
7
+ import huggingface_hub
5
8
import torch
6
9
import torch .nn as nn
7
10
13
16
process_weights_after_loading ,
14
17
set_default_torch_dtype ,
15
18
)
19
+ from vllm .validation .plugins import ModelType , ModelValidationPluginRegistry
16
20
17
21
logger = init_logger (__name__ )
18
22
19
23
24
class DownloadType(int, enum.Enum):
    """Enumerates the ways a model's files can be obtained.

    Members mix in ``int``, so each one compares equal to its plain
    integer value.
    """

    # Fetched as a snapshot from the Hugging Face Hub.
    HUGGINGFACE_HUB = 1
    # Already available at a local path.
    LOCAL_FILE = 2
    # not currently supported
    S3 = 3
    # Fallback member for any other download method.
    UNKNOWN = 4
29
+
30
+
20
31
class BaseModelLoader (ABC ):
21
32
"""Base class for model loaders."""
22
33
@@ -34,6 +45,45 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
34
45
inplace weights loading for an already-initialized model"""
35
46
raise NotImplementedError
36
47
48
def get_download_type(self, model_name_or_path: str) -> Optional[DownloadType]:
    """Report how the files of ``model_name_or_path`` are obtained.

    Subclasses must override this and return the download type they
    need. This base implementation always returns ``None``.

    Args:
        model_name_or_path: Model identifier or path to classify.

    Returns:
        The download type for the model, or ``None`` from this default.
    """
    download_type: Optional[DownloadType] = None
    return download_type
51
+
52
def download_all_files(
    self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig
) -> Optional[str]:
    """Make all files of the model available and return their folder.

    The download type is obtained from ``get_download_type``. Hugging
    Face Hub downloads are handled here directly because they are so
    common; local models are returned as-is.

    Args:
        model: The model being loaded (not used by this implementation).
        model_config: Supplies the model name/path (``model``) and the
            ``revision`` to download.
        load_config: Supplies ``download_dir``, used as the Hub cache
            directory.

    Returns:
        The folder containing the model files, or ``None`` when the
        download type is unknown or not handled here.
    """
    download_type = self.get_download_type(model_config.model)
    if download_type is None:
        # The loader did not declare a download type; nothing to fetch.
        return None

    if download_type == DownloadType.LOCAL_FILE:
        # The configured model value is already a local path.
        return model_config.model

    if download_type == DownloadType.HUGGINGFACE_HUB:
        return huggingface_hub.snapshot_download(
            model_config.model,
            # Fetch every file in the repo, not only the weights.
            allow_patterns=["*"],
            # Use the load_config passed by the caller rather than
            # relying on an instance attribute that may differ from it.
            cache_dir=load_config.download_dir,
            revision=model_config.revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )

    # Any other download type is not handled here.
    return None
69
+
70
def validate_model(
    self, model: nn.Module, model_config: ModelConfig, load_config: LoadConfig
) -> None:
    """Validate the model, if required, once all of its files are present.

    Nothing happens unless the validation plugin registry reports that
    validation is needed for this model. Otherwise all files are
    downloaded via ``download_all_files`` and the resulting folder is
    passed to the registry for validation.

    Args:
        model: The model being loaded; forwarded to ``download_all_files``.
        model_config: Supplies the model name/path (``model``).
        load_config: Forwarded to ``download_all_files``.

    Raises:
        RuntimeError: If validation is needed but ``download_all_files``
            returned no folder.
    """
    registry = ModelValidationPluginRegistry
    model_type = ModelType.MODEL_TYPE_AI_MODEL

    # Guard clause: skip everything when no validation is required.
    if not registry.model_validation_needed(model_type, model_config.model):
        return

    model_dir = self.download_all_files(model, model_config, load_config)
    if model_dir is None:
        raise RuntimeError(
            "Model validation could not be done due to "
            "an unsupported download method."
        )

    registry.validate_model(model_type, model_dir, model_config.model)
86
+
37
87
def load_model (
38
88
self , vllm_config : VllmConfig , model_config : ModelConfig
39
89
) -> nn .Module :
@@ -51,6 +101,7 @@ def load_model(
51
101
)
52
102
53
103
logger .debug ("Loading weights on %s ..." , load_device )
104
+ self .validate_model (model , model_config , vllm_config .load_config )
54
105
# Quantization does not happen in `load_weights` but after it
55
106
self .load_weights (model , model_config )
56
107
process_weights_after_loading (model , model_config , target_device )
0 commit comments