
Commit 3dad944

Add quantized mixtral support (#2673)
Parent commit: 105a40f

3 files changed: +422 −4 lines changed

vllm/model_executor/model_loader.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -4,7 +4,6 @@
 
 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, LoRAConfig
 from vllm.model_executor.models import ModelRegistry
@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(old_dtype)
 
 
-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
-    architectures = getattr(config, "architectures", [])
+def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
+    architectures = getattr(model_config.hf_config, "architectures", [])
+    # Special handling for quantized Mixtral.
+    # FIXME(woosuk): This is a temporary hack.
+    if (model_config.quantization is not None
+            and "MixtralForCausalLM" in architectures):
+        architectures = ["QuantMixtralForCausalLM"]
+
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
         if model_cls is not None:
@@ -34,7 +39,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig,
               lora_config: Optional[LoRAConfig] = None) -> nn.Module:
-    model_class = _get_model_architecture(model_config.hf_config)
+    model_class = _get_model_architecture(model_config)
 
     # Get the (maybe quantized) linear method.
     linear_method = None
```
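The behavioral change in `_get_model_architecture` boils down to a single override rule: when a quantization method is configured and the HF config names `MixtralForCausalLM`, the lookup is redirected to `QuantMixtralForCausalLM`, presumably because the stock Mixtral implementation's fused expert path does not work with quantized linear layers. Below is a minimal, self-contained sketch of that rule; `pick_architecture` and the example values are illustrative names, not part of the commit.

```python
from typing import Optional

def pick_architecture(quantization: Optional[str],
                      architectures: list[str]) -> list[str]:
    # Mirrors the diff above: quantized Mixtral is routed to a dedicated
    # architecture name so the registry loads the quantization-friendly
    # implementation instead of the stock one.
    if quantization is not None and "MixtralForCausalLM" in architectures:
        return ["QuantMixtralForCausalLM"]
    return architectures

# Unquantized: the stock Mixtral implementation is kept.
assert pick_architecture(None, ["MixtralForCausalLM"]) == ["MixtralForCausalLM"]
# Quantized (e.g. "awq"): rerouted to the quantized variant.
assert pick_architecture("awq", ["MixtralForCausalLM"]) == ["QuantMixtralForCausalLM"]
```

Note that `get_model` now passes the whole `ModelConfig` rather than just `model_config.hf_config`, since the override needs `model_config.quantization` in addition to the HF architecture list.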

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,6 +30,7 @@
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
```
