
Commit 6b46c4b

Add Nvidia ModelOpt config adaptation (#19815)
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent d978410 commit 6b46c4b
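
This commit teaches vLLM to recognize ModelOpt checkpoints whether the quantization settings use the nested `hf_quant_config.json` layout or a flattened, compressed-tensors style layout with `quant_method: "modelopt"`. Below is a minimal sketch of the two shapes, assuming a vLLM build that includes this change; the dicts are illustrative, not taken from a real checkpoint, and `is_checkpoint_fp8_serialized` is the attribute set by the config constructor in `modelopt.py`.

# A minimal sketch, assuming vLLM with this commit installed.
# Both config dicts are illustrative examples, not real checkpoint files.
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

# Nested ModelOpt layout, as found in hf_quant_config.json:
nested_cfg = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

# Flattened, compressed-tensors style layout:
flat_cfg = {"quant_method": "modelopt", "quant_algo": "FP8"}

# Both now yield an FP8-serialized config object.
cfg_a = ModelOptFp8Config.from_config(nested_cfg)
cfg_b = ModelOptFp8Config.from_config(flat_cfg)
assert cfg_a.is_checkpoint_fp8_serialized
assert cfg_b.is_checkpoint_fp8_serialized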

File tree: 3 files changed, +287 -32 lines changed

tests/quantization/test_modelopt.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Test ModelOpt quantization method setup and weight loading.
+
+Run `pytest tests/quantization/test_modelopt.py`.
+"""
+
+import os
+
+import pytest
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_v0_only(monkeypatch):
+    """
+    This module relies on V0 internals, so set VLLM_USE_V1=0.
+    """
+    if not current_platform.is_cpu():
+        monkeypatch.setenv('VLLM_USE_V1', '0')
+
+
+@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
+                    reason="ModelOpt FP8 is not supported on this GPU type.")
+def test_modelopt_fp8_checkpoint_setup(vllm_runner):
+    """Test ModelOpt FP8 checkpoint loading and structure validation."""
+    # TODO: provide a small publically available test checkpoint
+    model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
+                  "TinyLlama-1.1B-Chat-v1.0-fp8-0710")
+
+    # Skip test if checkpoint doesn't exist
+    if not os.path.exists(model_path):
+        pytest.skip(f"Test checkpoint not found at {model_path}. "
+                    "This test requires a local ModelOpt FP8 checkpoint.")
+
+    with vllm_runner(model_path, quantization="modelopt",
+                     enforce_eager=True) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            o_proj = layer.self_attn.o_proj
+            gate_up_proj = layer.mlp.gate_up_proj
+            down_proj = layer.mlp.down_proj
+
+            # Check that ModelOpt quantization method is properly applied
+            from vllm.model_executor.layers.quantization.modelopt import (
+                ModelOptFp8LinearMethod)
+            assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod)
+            assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod)
+            assert isinstance(gate_up_proj.quant_method,
+                              ModelOptFp8LinearMethod)
+            assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod)
+
+            # Check weight dtype is FP8
+            assert qkv_proj.weight.dtype == torch.float8_e4m3fn
+            assert o_proj.weight.dtype == torch.float8_e4m3fn
+            assert gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            assert down_proj.weight.dtype == torch.float8_e4m3fn
+
+            # Check scales are present and have correct dtype
+            assert hasattr(qkv_proj, 'weight_scale')
+            assert hasattr(qkv_proj, 'input_scale')
+            assert qkv_proj.weight_scale.dtype == torch.float32
+            assert qkv_proj.input_scale.dtype == torch.float32
+
+            assert hasattr(o_proj, 'weight_scale')
+            assert hasattr(o_proj, 'input_scale')
+            assert o_proj.weight_scale.dtype == torch.float32
+            assert o_proj.input_scale.dtype == torch.float32
+
+            assert hasattr(gate_up_proj, 'weight_scale')
+            assert hasattr(gate_up_proj, 'input_scale')
+            assert gate_up_proj.weight_scale.dtype == torch.float32
+            assert gate_up_proj.input_scale.dtype == torch.float32
+
+            assert hasattr(down_proj, 'weight_scale')
+            assert hasattr(down_proj, 'input_scale')
+            assert down_proj.weight_scale.dtype == torch.float32
+            assert down_proj.input_scale.dtype == torch.float32
+
+        llm.apply_model(check_model)
+
+        # Run a simple generation test to ensure the model works
+        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
+        assert output
+        print(f"ModelOpt FP8 output: {output}")
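
Outside the pytest harness, the same checkpoint setup can be exercised through vLLM's public `LLM` API; a minimal sketch, assuming a local ModelOpt FP8 checkpoint at a placeholder path:

# Minimal sketch: load a ModelOpt FP8 checkpoint directly with vLLM.
# "/path/to/modelopt-fp8-checkpoint" is a placeholder, not a real model.
from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/modelopt-fp8-checkpoint",
          quantization="modelopt",
          enforce_eager=True)
params = SamplingParams(temperature=0.0, max_tokens=20)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)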

vllm/config.py

Lines changed: 13 additions & 7 deletions
@@ -346,11 +346,11 @@ class ModelConfig:
     """Maximum number of data items per modality per prompt. Only applicable
     for multimodal models."""
     interleave_mm_strings: bool = False
-    """Enable fully interleaved support for multimodal prompts, while using
+    """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string. Defaults to False."""
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    """Additional args passed to process media inputs, keyed by modalities.
-    For example, to set num_frames for video, set
+    """Additional args passed to process media inputs, keyed by modalities.
+    For example, to set num_frames for video, set
     `--media-io-kwargs '{"video": {"num_frames": 40} }'` """
     use_async_output_proc: bool = True
     """Whether to use async output processor."""
@@ -1000,9 +1000,13 @@ def _verify_quantization(self) -> None:
         quant_cfg = self._parse_quant_hf_config()

         if quant_cfg is not None:
+            # Use the community standard 'quant_method'
             quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Normalize library names
             quant_method = quant_method.replace("compressed_tensors",
                                                 "compressed-tensors")
+
             quant_cfg["quant_method"] = quant_method

             # Quantization methods which are overrides (i.e. they have a
@@ -1017,6 +1021,8 @@ def _verify_quantization(self) -> None:
                 "awq_marlin",
                 "ipex",
                 "moe_wna16",
+                "modelopt",
+                "modelopt_fp4",
             ]
             quantization_methods = [
                 q for q in supported_quantization if q not in overrides
@@ -3185,8 +3191,8 @@ class MultiModalConfig:
     """

     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    """Additional args passed to process media inputs, keyed by modalities.
-    For example, to set num_frames for video, set
+    """Additional args passed to process media inputs, keyed by modalities.
+    For example, to set num_frames for video, set
     `--media-io-kwargs '{"video": {"num_frames": 40} }'` """

     mm_processor_kwargs: Optional[dict[str, object]] = None
@@ -4086,7 +4092,7 @@ class CompilationConfig:
     - True: inductor compilation is used (custom_ops disabled by default).
       One graph for symbolic shape and one graph per size in compile_sizes
       are compiled using configurations in inductor_compile_config.
-
+
     This setting is ignored if level<PIECEWISE."""
     compile_sizes: Optional[list[Union[int, str]]] = None
     """Sizes to compile for inductor. In addition
@@ -4385,7 +4391,7 @@ class VllmConfig:

     As a shorthand, `-O<n>` can be used to directly specify the compilation
     level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
-    Currently, -O <n> and -O=<n> are supported as well but this will likely be
+    Currently, -O <n> and -O=<n> are supported as well but this will likely be
     removed in favor of clearer -O<n> syntax in the future.

     NOTE: level 0 is the default level without any optimization. level 1 and 2
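
With "modelopt" and "modelopt_fp4" now listed as overrides in `_verify_quantization` above, vLLM consults the `override_quantization_method` hooks added in `modelopt.py` below when resolving a checkpoint's quantization method. A rough illustration of that detection, assuming a vLLM build with this commit; the config dicts are hypothetical:

# Rough illustration of the detection added in modelopt.py below.
# The config dicts are hypothetical; 'quant_method' must be "modelopt".
from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptFp8Config, ModelOptNvFp4Config)

fp8_cfg = {"quant_method": "modelopt",
           "quantization": {"quant_algo": "FP8"}}
fp4_cfg = {"quant_method": "modelopt",
           "quantization": {"quant_algo": "NVFP4"}}

assert ModelOptFp8Config.override_quantization_method(fp8_cfg, None) == "modelopt"
assert ModelOptNvFp4Config.override_quantization_method(fp4_cfg, None) == "modelopt_fp4"
# An NVFP4 checkpoint does not match the FP8 config, so detection returns None.
assert ModelOptFp8Config.override_quantization_method(fp4_cfg, None) is None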

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 183 additions & 25 deletions
@@ -75,20 +75,64 @@ def get_min_capability(cls) -> int:
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]

+    @classmethod
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
+        """Detect if this ModelOpt config should be used based on
+        quantization config."""
+
+        if hf_quant_cfg is None:
+            return None
+
+        # Use the community standard 'quant_method'
+        quant_method = hf_quant_cfg.get("quant_method", "").lower()
+
+        # Only proceed if the method is explicitly "modelopt"
+        if quant_method != "modelopt":
+            return None
+
+        # Look for ModelOpt-specific config structure
+        if "quantization" in hf_quant_cfg:
+            quant_config = hf_quant_cfg["quantization"]
+            if isinstance(quant_config, dict):
+                quant_algo = quant_config.get("quant_algo", "")
+                if "FP8" in quant_algo:
+                    return "modelopt"
+        else:
+            # Check for compressed-tensors style config with specific quant_algo
+            quant_algo = hf_quant_cfg.get("quant_algo", "")
+            if isinstance(quant_algo, str) and "FP8" in quant_algo:
+                return "modelopt"
+
+        return None
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
-        kv_cache_quant_method = cls.get_from_keys(
-            config, ["quantization"]).get("kv_cache_quant_algo")
-        exclude_modules = cls.get_from_keys(
-            config, ["quantization"]).get("exclude_modules")
+        # Handle both ModelOpt format and compressed-tensors style format
+        if "quantization" in config:
+            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
+            quant_config = cls.get_from_keys(config, ["quantization"])
+            if not isinstance(quant_config, dict):
+                raise ValueError(
+                    "Expected 'quantization' to be a dictionary in config")
+            quant_method = quant_config.get("quant_algo", "")
+            if not quant_method:
+                raise ValueError("Missing 'quant_algo' in quantization config")
+            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
+            exclude_modules = quant_config.get("exclude_modules")
+        else:
+            # Compressed-tensors style format:
+            # {"quant_algo": "...", "quant_method": "modelopt"}
+            quant_method = config.get("quant_algo", "")
+            kv_cache_quant_method = config.get("kv_cache_quant_algo")
+            exclude_modules = config.get("exclude_modules")

         if quant_method not in QUANT_ALGOS:
-            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
-                             " quantizations in vLLM. Please check the "
-                             "`hf_quant_config.json` file for your model's "
-                             "quant configuration.")
+            raise ValueError(
+                f"ModelOpt currently only supports: {QUANT_ALGOS} "
+                "quantizations in vLLM. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration.")
         is_checkpoint_fp8_serialized = ("FP8" in quant_method)

         return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
@@ -434,7 +478,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
     def __init__(
         self,
         is_checkpoint_nvfp4_serialized: bool,
-        kv_cache_quant_algo: str,
+        kv_cache_quant_algo: Optional[str],
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
@@ -465,24 +509,138 @@ def get_min_capability(cls) -> int:
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]

+    @classmethod
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
+        """Detect if this ModelOpt FP4 config should be used based on
+        quantization config."""
+        if hf_quant_cfg is None:
+            return None
+
+        # Use the community standard 'quant_method'
+        quant_method = hf_quant_cfg.get("quant_method", "").lower()
+
+        # Only proceed if the method is explicitly "modelopt"
+        if quant_method != "modelopt":
+            return None
+
+        # Look for ModelOpt-specific config structure
+        if "quantization" in hf_quant_cfg:
+            quant_config = hf_quant_cfg["quantization"]
+            if isinstance(quant_config, dict):
+                quant_algo = quant_config.get("quant_algo", "")
+                if "NVFP4" in quant_algo:
+                    return "modelopt_fp4"
+        else:
+            # Check for compressed-tensors style config with specific
+            # quant_algo field
+            quant_algo = hf_quant_cfg.get("quant_algo", "")
+            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
+                return "modelopt_fp4"
+
+        return None
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
+        # Handle both traditional ModelOpt format and compressed-tensors
+        # style format
+        if "quantization" in config:
+            # Traditional ModelOpt format:
+            # {"quantization": {"quant_algo": "..."}}
+            quant_config = cls.get_from_keys(config, ["quantization"])
+            if not isinstance(quant_config, dict):
+                raise ValueError(
+                    "Expected 'quantization' to be a dictionary in config")
+
+            quant_method = quant_config.get("quant_algo", "")
+            if not quant_method:
+                raise ValueError("Missing 'quant_algo' in quantization config")
+
+            # Handle kv_cache_quant_algo with proper type validation
+            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
+            if kv_cache_quant_algo_raw is None:
+                # No KV cache quantization by default
+                kv_cache_quant_algo = None
+            elif isinstance(kv_cache_quant_algo_raw, str):
+                kv_cache_quant_algo = kv_cache_quant_algo_raw
+            else:
+                raise ValueError(f"kv_cache_quant_algo must be a string, got "
+                                 f"{type(kv_cache_quant_algo_raw)}")
+
+            # Handle group_size with proper type validation
+            group_size_raw = quant_config.get("group_size")
+            if group_size_raw is None:
+                group_size = 16  # Default value
+            elif isinstance(group_size_raw, int):
+                group_size = group_size_raw
+            else:
+                try:
+                    group_size = int(group_size_raw)
+                except (ValueError, TypeError):
+                    raise ValueError(f"group_size must be an integer, got "
+                                     f"{type(group_size_raw)}") from None
+
+            exclude_modules = quant_config.get("exclude_modules", [])
+            if not isinstance(exclude_modules, list):
+                raise ValueError(f"exclude_modules must be a list, got "
+                                 f"{type(exclude_modules)}")
+        else:
+            # Compressed-tensors style format:
+            # {"quant_algo": "...", "quant_method": "modelopt"}
+            quant_method = config.get("quant_algo", "")
+
+            # Handle kv_cache_quant_algo with proper type validation
+            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
+            if kv_cache_quant_algo_raw is None:
+                # No KV cache quantization by default
+                kv_cache_quant_algo = None
+            elif isinstance(kv_cache_quant_algo_raw, str):
+                kv_cache_quant_algo = kv_cache_quant_algo_raw
+            else:
+                raise ValueError(f"kv_cache_quant_algo must be a string, got "
+                                 f"{type(kv_cache_quant_algo_raw)}")
+
+            # Handle group_size with proper type validation
+            group_size_raw = config.get("group_size")
+            if group_size_raw is None:
+                group_size = 16  # Default value
+            elif isinstance(group_size_raw, int):
+                group_size = group_size_raw
+            else:
+                try:
+                    group_size = int(group_size_raw)
+                except (ValueError, TypeError):
+                    raise ValueError(f"group_size must be an integer, got "
+                                     f"{type(group_size_raw)}") from None
+
+            exclude_modules = config.get("exclude_modules", [])
+            if not isinstance(exclude_modules, list):
+                raise ValueError(f"exclude_modules must be a list, got "
+                                 f"{type(exclude_modules)}")
+
         if quant_method not in QUANT_ALGOS:
-            raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
-                             " quantizations in vLLM. Please check the "
-                             "`hf_quant_config.json` file for your model's "
-                             "quant configuration.")
+            raise ValueError(
+                f"ModelOpt currently only supports: {QUANT_ALGOS} "
+                "quantizations in vLLM. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration.")
         is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
-        if ("group_size" and "kv_cache_quant_algo"
-                and "exclude_modules") not in quant_config:
-            raise ValueError("NVFP4 quantization requires group size and "
-                             "kv_cache_quant_algo specified in "
-                             "hf_quant_config.json")
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
+
+        # For FP4, these fields are required
+        if is_checkpoint_nvfp4_serialized and "quantization" in config:
+            # Check if required fields are present in the quantization config
+            quant_config = config["quantization"]
+            required_fields = [
+                "group_size", "kv_cache_quant_algo", "exclude_modules"
+            ]
+            missing_fields = [
+                field for field in required_fields if field not in quant_config
+            ]
+            if missing_fields:
+                raise ValueError(
+                    f"NVFP4 quantization requires the following fields in "
+                    f"hf_quant_config.json: {missing_fields}")
+
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)

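
To round out the NVFP4 path: in the flattened layout, missing fields fall back to defaults (`group_size=16`, no KV-cache quantization), while the nested layout must spell out the NVFP4 fields or `from_config` raises. A small sketch, assuming a vLLM build with this commit and hypothetical config dicts; the attribute names follow the `ModelOptNvFp4Config` constructor.

# Sketch of the NVFP4 parsing added above; the configs are hypothetical.
from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

# Flattened form: group_size defaults to 16, kv_cache_quant_algo to None.
flat = {"quant_method": "modelopt", "quant_algo": "NVFP4"}
cfg = ModelOptNvFp4Config.from_config(flat)
assert cfg.group_size == 16
assert cfg.kv_cache_quant_algo is None

# Nested form: the NVFP4 fields must all be present, or a ValueError is raised.
nested = {
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}
cfg = ModelOptNvFp4Config.from_config(nested)
assert cfg.is_checkpoint_nvfp4_serialized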
