44
55import torch
66import torch .nn as nn
7- from transformers import PretrainedConfig
87
98from vllm .config import ModelConfig , LoRAConfig
109from vllm .model_executor .models import ModelRegistry
@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
2120 torch .set_default_dtype (old_dtype )
2221
2322
24- def _get_model_architecture (config : PretrainedConfig ) -> Type [nn .Module ]:
25- architectures = getattr (config , "architectures" , [])
23+ def _get_model_architecture (model_config : ModelConfig ) -> Type [nn .Module ]:
24+ architectures = getattr (model_config .hf_config , "architectures" , [])
25+ # Special handling for quantized Mixtral.
26+ # FIXME(woosuk): This is a temporary hack.
27+ if (model_config .quantization is not None
28+ and "MixtralForCausalLM" in architectures ):
29+ architectures = ["QuantMixtralForCausalLM" ]
30+
2631 for arch in architectures :
2732 model_cls = ModelRegistry .load_model_cls (arch )
2833 if model_cls is not None :
@@ -34,7 +39,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
3439
3540def get_model (model_config : ModelConfig ,
3641 lora_config : Optional [LoRAConfig ] = None ) -> nn .Module :
37- model_class = _get_model_architecture (model_config . hf_config )
42+ model_class = _get_model_architecture (model_config )
3843
3944 # Get the (maybe quantized) linear method.
4045 linear_method = None
0 commit comments