from dataclasses import dataclass, field
from typing import AbstractSet, Mapping, Optional


@dataclass(frozen=True)
class _HfExamplesInfo:
    default: str
    """The default model to use for testing this architecture."""

    extras: Mapping[str, str] = field(default_factory=dict)
    """Extra models to use for testing this architecture."""

    tokenizer: Optional[str] = None
    """Set the tokenizer to load for this architecture."""

    tokenizer_mode: str = "auto"
    """Set the tokenizer type for this architecture."""

    speculative_model: Optional[str] = None
    """
    The default model to use for testing this architecture; only used for
    speculative decoding.
    """

    is_available_online: bool = True
    """
    Set this to ``False`` if the model for this architecture is no longer
    available on the HF Hub. To maintain backwards compatibility, such
    architectures are kept in the main model registry, so without this flag
    the registry tests would fail.
    """

    trust_remote_code: bool = False
    """Whether ``trust_remote_code`` is required to load the model."""


# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
                                   trust_remote_code=True),
    "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
                                         trust_remote_code=True),
    "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
                                         trust_remote_code=True),
    "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
                                           trust_remote_code=True),
    "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
                                           trust_remote_code=True),
    "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
    # ChatGLMModel supports multimodal
    "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
                                         trust_remote_code=True),
    "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
    "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
                                         trust_remote_code=True),
    "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
    "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat",  # noqa: E501
                                             trust_remote_code=True),
    "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),  # noqa: E501
    "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
    "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
    "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
    "GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
    "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
    "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
    "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
    "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
    "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
    "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
                                           trust_remote_code=True),
    "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
                                            trust_remote_code=True),
    "InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B",
                                              trust_remote_code=True),
    "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
    "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
    "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),
    "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                        is_available_online=False),
    "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
    "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
    "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                          trust_remote_code=True),
    "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
                                           trust_remote_code=True),
    "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
    "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"),  # noqa: E501
    "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"),  # noqa: E501
    "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
    "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
    "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
    "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
    "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
    "OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"),
    "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                        trust_remote_code=True),
    "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
    "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
    "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                            trust_remote_code=True),
    "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                         trust_remote_code=True),
    # QWenLMHeadModel supports multimodal
    "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
    "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
                                     is_available_online=False),
    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
                                                is_available_online=False),
    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
    "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
    "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
    "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
                                         is_available_online=False,
                                         trust_remote_code=True),
    # [Encoder-decoder]
    "BartModel": _HfExamplesInfo("facebook/bart-base"),
    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
    # Therefore, we borrow the BartTokenizer from the original Bart model
    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
                                                         tokenizer="facebook/bart-base",  # noqa: E501
                                                         trust_remote_code=True),  # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {
    # [Text-only]
    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
    "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
    "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
    "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
    # [Multimodal]
    "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
    "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                        trust_remote_code=True),
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),  # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
    # [Decoder-only]
    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"),  # noqa: E501
    "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
    "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
                                    extras={"text_only": "THUDM/chatglm3-6b"},
                                    trust_remote_code=True),
    "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
                                                       is_available_online=False),  # noqa: E501
    "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
    "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
    "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
                                         trust_remote_code=True),
    "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"),  # noqa: E501
    "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                     extras={"mistral": "mistral-community/pixtral-12b"}),  # noqa: E501
    "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"),  # noqa: E501
    "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                trust_remote_code=True),
    "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
                                        trust_remote_code=True),
    "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
                              trust_remote_code=True),
    "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"),  # noqa: E501
    "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
                                        trust_remote_code=True),
    "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
                                                       tokenizer_mode="mistral"),
    "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat",
                                       extras={"text_only": "Qwen/Qwen-7B-Chat"},  # noqa: E501
                                       trust_remote_code=True),
    "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
}

_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
    "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m",
                                  speculative_model="abhigoyal/vllm-eagle-llama-68m-random"),  # noqa: E501
    "MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
                                   speculative_model="abhigoyal/vllm-medusa-llama-68m-random"),  # noqa: E501
    "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
                                                    speculative_model="ibm-fms/llama-160m-accelerator"),  # noqa: E501
}
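
# Note: architectures that appear in more than one dict above (e.g.
# "Phi3VForCausalLM" and "Qwen2VLForConditionalGeneration" are listed under
# both embedding and multimodal) resolve to the later entry here, since later
# ``**`` expansions overwrite earlier keys.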
_EXAMPLE_MODELS = {
    **_TEXT_GENERATION_EXAMPLE_MODELS,
    **_EMBEDDING_EXAMPLE_MODELS,
    **_MULTIMODAL_EXAMPLE_MODELS,
    **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
}


class HfExampleModels:
    def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None:
        super().__init__()

        self.hf_models = hf_models

    def get_supported_archs(self) -> AbstractSet[str]:
        return self.hf_models.keys()

    def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
        return self.hf_models[model_arch]


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
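

# Usage sketch: iterate the registry and report each architecture's default
# example model. This demo block is an illustrative addition, not part of the
# original registry module.
if __name__ == "__main__":
    for arch in sorted(HF_EXAMPLE_MODELS.get_supported_archs()):
        info = HF_EXAMPLE_MODELS.get_hf_info(arch)
        print(f"{arch}: {info.default}")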