Commit 518369d

Implement lazy model loader (#2044)
1 parent 30bad5c commit 518369d

File tree

3 files changed: +89 additions, -101 deletions
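
The commit replaces eager `from vllm.model_executor.models import *` imports with a string-keyed registry that resolves model classes on first use, so importing vLLM no longer pulls in every supported model module. A minimal sketch of the pattern, with illustrative names (`_REGISTRY` and `resolve` are not the commit's identifiers):

import importlib

# Map architecture name -> (module, class) strings; nothing is imported yet.
_REGISTRY = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}

def resolve(arch: str):
    # Import the containing module only when the class is first requested.
    module_name, cls_name = _REGISTRY[arch]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)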

vllm/model_executor/model_loader.py

Lines changed: 5 additions & 57 deletions
@@ -7,54 +7,9 @@
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import *
+from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
-from vllm.utils import is_hip
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-# TODO(woosuk): Lazy-load the model classes.
-_MODEL_REGISTRY = {
-    "AquilaModel": AquilaForCausalLM,
-    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
-    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
-    "BloomForCausalLM": BloomForCausalLM,
-    "ChatGLMModel": ChatGLMForCausalLM,
-    "ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
-    "FalconForCausalLM": FalconForCausalLM,
-    "GPT2LMHeadModel": GPT2LMHeadModel,
-    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
-    "GPTJForCausalLM": GPTJForCausalLM,
-    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
-    "InternLMForCausalLM": InternLMForCausalLM,
-    "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
-    "MistralForCausalLM": MistralForCausalLM,
-    "MixtralForCausalLM": MixtralForCausalLM,
-    # transformers's mpt class has lower case
-    "MptForCausalLM": MPTForCausalLM,
-    "MPTForCausalLM": MPTForCausalLM,
-    "OPTForCausalLM": OPTForCausalLM,
-    "PhiForCausalLM": PhiForCausalLM,
-    "QWenLMHeadModel": QWenLMHeadModel,
-    "RWForCausalLM": FalconForCausalLM,
-    "YiForCausalLM": YiForCausalLM,
-}
-
-# Models to be disabled in ROCm
-_ROCM_UNSUPPORTED_MODELS = []
-if is_hip():
-    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
-        del _MODEL_REGISTRY[rocm_model]
-
-# Models partially supported in ROCm
-_ROCM_PARTIALLY_SUPPORTED_MODELS = {
-    "MistralForCausalLM":
-    "Sliding window attention is not supported in ROCm's flash attention",
-}


 @contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
-        if arch in _MODEL_REGISTRY:
-            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-                logger.warning(
-                    f"{arch} is not fully supported in ROCm. Reason: "
-                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
-            return _MODEL_REGISTRY[arch]
-        elif arch in _ROCM_UNSUPPORTED_MODELS:
-            raise ValueError(
-                f"Model architecture {arch} is not supported by ROCm for now. \n"
-                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


 def get_model(model_config: ModelConfig) -> nn.Module:
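
After this change, `_get_model_architecture` just asks the registry for each architecture listed in the Hugging Face config. A usage sketch (assumes vLLM at this commit; `facebook/opt-125m` is an arbitrary example model id):

from transformers import AutoConfig
from vllm.model_executor.models import ModelRegistry

config = AutoConfig.from_pretrained("facebook/opt-125m")
for arch in getattr(config, "architectures", []):  # e.g. ["OPTForCausalLM"]
    model_cls = ModelRegistry.load_model_cls(arch)
    if model_cls is not None:
        print(model_cls.__name__)  # the vLLM class implementing this architecture
        break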
vllm/model_executor/models/__init__.py

Lines changed: 77 additions & 38 deletions
@@ -1,41 +1,80 @@
-from vllm.model_executor.models.aquila import AquilaForCausalLM
-from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
-                                                 BaichuanForCausalLM)
-from vllm.model_executor.models.bloom import BloomForCausalLM
-from vllm.model_executor.models.falcon import FalconForCausalLM
-from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
-from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
-from vllm.model_executor.models.gpt_j import GPTJForCausalLM
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
-from vllm.model_executor.models.internlm import InternLMForCausalLM
-from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mixtral import MixtralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
-from vllm.model_executor.models.qwen import QWenLMHeadModel
-from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
-from vllm.model_executor.models.yi import YiForCausalLM
+import importlib
+from typing import List, Optional, Type
+
+import torch.nn as nn
+
+from vllm.logger import init_logger
+from vllm.utils import is_hip
+
+logger = init_logger(__name__)
+
+# Architecture -> (module, class).
+_MODELS = {
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
+    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
+    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
+    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
+    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    # For decapoda-research/llama-*
+    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    # transformers's mpt class has lower case
+    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM"),
+}
+
+# Models not supported by ROCm.
+_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+
+# Models partially supported by ROCm.
+# Architecture -> Reason.
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
+}
+
+
+class ModelRegistry:
+
+    @staticmethod
+    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch not in _MODELS:
+            return None
+        if is_hip():
+            if model_arch in _ROCM_UNSUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "ROCm for now.")
+            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"Model architecture {model_arch} is partially supported "
+                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+
+        module_name, model_cls_name = _MODELS[model_arch]
+        module = importlib.import_module(
+            f"vllm.model_executor.models.{module_name}")
+        return getattr(module, model_cls_name, None)
+
+    @staticmethod
+    def get_supported_archs() -> List[str]:
+        return list(_MODELS.keys())
+

 __all__ = [
-    "AquilaForCausalLM",
-    "BaiChuanForCausalLM",
-    "BaichuanForCausalLM",
-    "BloomForCausalLM",
-    "ChatGLMForCausalLM",
-    "FalconForCausalLM",
-    "GPT2LMHeadModel",
-    "GPTBigCodeForCausalLM",
-    "GPTJForCausalLM",
-    "GPTNeoXForCausalLM",
-    "InternLMForCausalLM",
-    "LlamaForCausalLM",
-    "MPTForCausalLM",
-    "OPTForCausalLM",
-    "PhiForCausalLM",
-    "QWenLMHeadModel",
-    "MistralForCausalLM",
-    "MixtralForCausalLM",
-    "YiForCausalLM",
+    "ModelRegistry",
 ]
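
Besides deferring imports, the registry centralizes the ROCm checks that model_loader.py used to perform inline. Based on the diff above, the expected behavior is:

from vllm.model_executor.models import ModelRegistry

print(ModelRegistry.get_supported_archs())  # all keys of _MODELS
assert ModelRegistry.load_model_cls("NoSuchArch") is None  # unknown names return None
# On a ROCm (HIP) build:
#   load_model_cls("MixtralForCausalLM") raises ValueError (unsupported),
#   load_model_cls("MistralForCausalLM") logs a warning, then loads the class.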

vllm/model_executor/models/mixtral.py

Lines changed: 7 additions & 6 deletions
@@ -33,14 +33,15 @@

 try:
     import megablocks.ops as ops
-except ImportError:
-    print(
-        "MegaBlocks not found. Please install it by `pip install megablocks`.")
+except ImportError as e:
+    raise ImportError("MegaBlocks not found. "
+                      "Please install it by `pip install megablocks`.") from e
 try:
     import stk
-except ImportError:
-    print(
-        "STK not found: please see https://github.com/stanford-futuredata/stk")
+except ImportError as e:
+    raise ImportError(
+        "STK not found. "
+        "Please install it by `pip install stanford-stk`.") from e

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention

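The mixtral.py change turns a print-and-continue into a raised ImportError chained with `from e`, so a missing optional dependency fails at import time with both the actionable message and the original traceback, rather than surfacing later as a confusing NameError on the unbound module. The general pattern, shown with a hypothetical dependency name:

try:
    import some_optional_lib  # hypothetical optional dependency
except ImportError as e:
    # Chaining with `from e` preserves the underlying import failure
    # in the traceback alongside the friendly message.
    raise ImportError("some_optional_lib not found. Install it with "
                      "`pip install some-optional-lib`.") from e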