Commit bb00f66

Use quantization_config in hf config (#1695)
1 parent e87557b commit bb00f66

File tree

vllm/config.py
vllm/model_executor/model_loader.py
vllm/model_executor/weight_utils.py

3 files changed: +34 -10 lines changed

vllm/config.py

Lines changed: 24 additions & 8 deletions
@@ -104,14 +104,30 @@ def _verify_tokenizer_mode(self) -> None:
 
     def _verify_quantization(self) -> None:
         supported_quantization = ["awq", "squeezellm"]
-        if self.quantization is None:
-            return
-        quantization = self.quantization.lower()
-        if quantization not in supported_quantization:
-            raise ValueError(
-                f"Unknown quantization: {self.quantization}. Must be one of "
-                f"{supported_quantization}.")
-        self.quantization = quantization
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
+        if hf_quant_config is not None:
+            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
+            if self.quantization is None:
+                self.quantization = hf_quant_method
+            elif self.quantization != hf_quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({hf_quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization}).")
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}.")
+            logger.warning(f"{self.quantization} quantization is not fully "
+                           "optimized yet. The speed can be slower than "
+                           "non-quantized models.")
 
     def verify_with_parallel_config(
         self,
vllm/model_executor/model_loader.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config.quantization,
                                         model_config.model,
+                                        model_config.hf_config,
                                         model_config.download_dir)
         capability = torch.cuda.get_device_capability()
         capability = capability[0] * 10 + capability[1]
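
The two capability lines in context pack the CUDA compute capability into a single integer so it can be compared against a quantization method's minimum requirement. A hedged sketch of that gate (the helper names and the 75 threshold are illustrative assumptions, not taken from this commit):

def encode_capability(major: int, minor: int) -> int:
    # (8, 6) -> 86, matching capability[0] * 10 + capability[1] above.
    return major * 10 + minor

def check_min_capability(capability: int, min_capability: int,
                         method: str) -> None:
    # min_capability would come from the quantization config class; the
    # value passed below is only an example.
    if capability < min_capability:
        raise ValueError(f"{method} requires compute capability "
                         f">= {min_capability}, got {capability}.")

check_min_capability(encode_capability(8, 6), 75, "awq")  # passes silently
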

vllm/model_executor/weight_utils.py

Lines changed: 9 additions & 2 deletions
@@ -7,9 +7,10 @@
 from typing import Any, Iterator, List, Optional, Tuple
 
 from huggingface_hub import snapshot_download
-from safetensors.torch import load_file, save_file, safe_open
 import numpy as np
+from safetensors.torch import load_file, save_file, safe_open
 import torch
+from transformers import PretrainedConfig
 from tqdm.auto import tqdm
 
 from vllm.logger import init_logger
@@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file(
 def get_quant_config(
     quantization: str,
     model_name_or_path: str,
+    hf_config: PretrainedConfig,
     cache_dir: Optional[str] = None,
 ) -> QuantizationConfig:
+    quant_cls = get_quantization_config(quantization)
+    # Read the quantization config from the HF model config, if available.
+    hf_quant_config = getattr(hf_config, "quantization_config", None)
+    if hf_quant_config is not None:
+        return quant_cls.from_config(hf_quant_config)
+
     is_local = os.path.isdir(model_name_or_path)
     if not is_local:
         # Download the config files.
@@ -98,7 +106,6 @@ def get_quant_config(
     hf_folder = model_name_or_path
     config_files = glob.glob(os.path.join(hf_folder, "*.json"))
 
-    quant_cls = get_quantization_config(quantization)
     quant_config_files = [
         f for f in config_files if any(
             f.endswith(x) for x in quant_cls.get_config_filenames())
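
Taken together, get_quant_config now prefers the quantization_config embedded in the HF config and only falls back to scanning the checkpoint directory for dedicated quantization JSON files. A self-contained sketch of that two-path lookup (AWQStandIn and its filename are illustrative stand-ins for the real quantization config classes):

import glob
import json
import os

class AWQStandIn:
    # Illustrative stand-in exposing the two classmethods the diff relies on.
    def __init__(self, weight_bits: int):
        self.weight_bits = weight_bits

    @classmethod
    def from_config(cls, config: dict) -> "AWQStandIn":
        return cls(weight_bits=config["bits"])

    @classmethod
    def get_config_filenames(cls) -> list:
        return ["quant_config.json"]

def get_quant_config_sketch(hf_quant_config, model_dir: str) -> AWQStandIn:
    # Path 1 (new in this commit): the HF config already embeds the settings.
    if hf_quant_config is not None:
        return AWQStandIn.from_config(hf_quant_config)
    # Path 2 (pre-existing fallback): find a dedicated quantization JSON.
    config_files = glob.glob(os.path.join(model_dir, "*.json"))
    matches = [f for f in config_files if any(
        f.endswith(x) for x in AWQStandIn.get_config_filenames())]
    if not matches:
        raise ValueError(f"No quantization config found in {model_dir}.")
    with open(matches[0]) as f:
        return AWQStandIn.from_config(json.load(f))

print(get_quant_config_sketch({"bits": 4}, ".").weight_bits)  # -> 4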
