-from vllm.model_executor.models.aquila import AquilaForCausalLM
-from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
-                                                 BaichuanForCausalLM)
-from vllm.model_executor.models.bloom import BloomForCausalLM
-from vllm.model_executor.models.falcon import FalconForCausalLM
-from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
-from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
-from vllm.model_executor.models.gpt_j import GPTJForCausalLM
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
-from vllm.model_executor.models.internlm import InternLMForCausalLM
-from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mixtral import MixtralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
-from vllm.model_executor.models.qwen import QWenLMHeadModel
-from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
-from vllm.model_executor.models.yi import YiForCausalLM
+import importlib
+from typing import List, Optional, Type
+
+import torch.nn as nn
+
+from vllm.logger import init_logger
+from vllm.utils import is_hip
+
+logger = init_logger(__name__)
+
+# Architecture -> (module, class).
+_MODELS = {
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
+    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
+    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
+    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
+    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
+    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
+    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    # For decapoda-research/llama-*
+    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    # transformers' MPT class name uses lower case ("MptForCausalLM").
+    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM"),
+}
+
+# Models not supported by ROCm.
+_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+
+# Models partially supported by ROCm.
+# Architecture -> Reason.
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
+}
+
+
+class ModelRegistry:
+
+    @staticmethod
+    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch not in _MODELS:
+            return None
+        if is_hip():
+            if model_arch in _ROCM_UNSUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "ROCm for now.")
+            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"Model architecture {model_arch} is partially supported "
+                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+
+        module_name, model_cls_name = _MODELS[model_arch]
+        module = importlib.import_module(
+            f"vllm.model_executor.models.{module_name}")
+        return getattr(module, model_cls_name, None)
+
+    @staticmethod
+    def get_supported_archs() -> List[str]:
+        return list(_MODELS.keys())
+
 
 __all__ = [
| - "AquilaForCausalLM", |
23 |
| - "BaiChuanForCausalLM", |
24 |
| - "BaichuanForCausalLM", |
25 |
| - "BloomForCausalLM", |
26 |
| - "ChatGLMForCausalLM", |
27 |
| - "FalconForCausalLM", |
28 |
| - "GPT2LMHeadModel", |
29 |
| - "GPTBigCodeForCausalLM", |
30 |
| - "GPTJForCausalLM", |
31 |
| - "GPTNeoXForCausalLM", |
32 |
| - "InternLMForCausalLM", |
33 |
| - "LlamaForCausalLM", |
34 |
| - "MPTForCausalLM", |
35 |
| - "OPTForCausalLM", |
36 |
| - "PhiForCausalLM", |
37 |
| - "QWenLMHeadModel", |
38 |
| - "MistralForCausalLM", |
39 |
| - "MixtralForCausalLM", |
40 |
| - "YiForCausalLM", |
| 79 | + "ModelRegistry", |
 ]
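
For reviewers who want to exercise the new API: below is a minimal usage sketch showing how a caller might resolve the architecture names from a Hugging Face config into a vLLM model class. The `architectures` value and the error handling are illustrative assumptions, not code from this commit.

    # Illustrative sketch (not part of this commit): resolve an HF config's
    # architecture list to a vLLM model class via the new registry.
    from vllm.model_executor.models import ModelRegistry

    # Assumed example input; in practice this would come from
    # hf_config.architectures.
    architectures = ["LlamaForCausalLM"]

    model_cls = None
    for arch in architectures:
        # Returns None for unknown architectures; on ROCm, raises ValueError
        # for architectures listed as unsupported.
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            break

    if model_cls is None:
        raise ValueError(
            f"Model architectures {architectures} are not supported. "
            f"Supported architectures: {ModelRegistry.get_supported_archs()}")

A note on the design: because `load_model_cls` imports each architecture's module through `importlib` only when it is requested, importing `vllm.model_executor.models` no longer pulls in every model file and its dependencies eagerly, which should reduce import-time work.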
|