@@ -4,7 +4,6 @@

 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig

 from vllm.config import ModelConfig, LoRAConfig
 from vllm.model_executor.models import ModelRegistry
@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(old_dtype)


-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
-    architectures = getattr(config, "architectures", [])
+def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
+    architectures = getattr(model_config.hf_config, "architectures", [])
+    # Special handling for quantized Mixtral.
+    # FIXME(woosuk): This is a temporary hack.
+    if (model_config.quantization is not None
+            and "MixtralForCausalLM" in architectures):
+        architectures = ["QuantMixtralForCausalLM"]
+
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
         if model_cls is not None:
@@ -34,7 +39,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:

 def get_model(model_config: ModelConfig,
               lora_config: Optional[LoRAConfig] = None) -> nn.Module:
-    model_class = _get_model_architecture(model_config.hf_config)
+    model_class = _get_model_architecture(model_config)

     # Get the (maybe quantized) linear method.
     linear_method = None
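For context, the diff changes `_get_model_architecture` to take the whole `ModelConfig` rather than just the HF config, so it can inspect `model_config.quantization` and reroute quantized Mixtral checkpoints to a dedicated `QuantMixtralForCausalLM` architecture before the registry lookup. A minimal standalone sketch of that resolution logic follows; the `SimpleNamespace` stand-ins for `ModelConfig` and the `_REGISTRY` dict are hypothetical placeholders, and only the control flow mirrors the change above:

from types import SimpleNamespace

# Hypothetical stand-in for vLLM's ModelRegistry: maps architecture
# names (as listed in the HF config) to model classes.
_REGISTRY = {
    "MixtralForCausalLM": "<unquantized Mixtral class>",
    "QuantMixtralForCausalLM": "<quantized Mixtral class>",
}

def resolve_architecture(model_config):
    architectures = getattr(model_config.hf_config, "architectures", [])
    # Mirror of the hack above: when quantization is enabled, Mixtral
    # checkpoints are rerouted to the quantized architecture before
    # the registry lookup.
    if (model_config.quantization is not None
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]
    for arch in architectures:
        model_cls = _REGISTRY.get(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(f"Unsupported architectures: {architectures}")

# Example: a quantized Mixtral config resolves to the quantized class.
cfg = SimpleNamespace(
    hf_config=SimpleNamespace(architectures=["MixtralForCausalLM"]),
    quantization="awq",
)
print(resolve_architecture(cfg))  # -> "<quantized Mixtral class>"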