Skip to content

Commit 4b2ed79

Browse files
authored
Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 7e35711 commit 4b2ed79

File tree

14 files changed

+461
-345
lines changed

14 files changed

+461
-345
lines changed

tests/compile/test_full_graph.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from tests.quantization.utils import is_quant_method_supported
1111
from vllm import LLM, SamplingParams
12-
from vllm.config import CompilationConfig, CompilationLevel
12+
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
1313
from vllm.platforms import current_platform
1414

1515
from ..utils import create_new_process_for_each_test
@@ -95,9 +95,6 @@ def test_full_graph(
9595
run_model(optimization_level, model, model_kwargs)
9696

9797

98-
PassConfig = CompilationConfig.PassConfig
99-
100-
10198
# TODO(luka) add other supported compilation config scenarios here
10299
@pytest.mark.parametrize(
103100
"compilation_config, model_info",

tests/compile/test_functionalization.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
kFp8DynamicTokenSym, kFp8StaticTensorSym)
1212
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
1313
from vllm.compilation.noop_elimination import NoOpEliminationPass
14-
from vllm.config import CompilationConfig, VllmConfig
14+
from vllm.config import CompilationConfig, PassConfig, VllmConfig
1515

1616
from .backend import TestBackend
1717

@@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
5353
torch.set_default_device("cuda")
5454

5555
vllm_config = VllmConfig()
56-
vllm_config.compilation_config = CompilationConfig(pass_config= \
57-
CompilationConfig.PassConfig(enable_fusion=do_fusion,
58-
enable_noop=True))
56+
vllm_config.compilation_config = CompilationConfig(
57+
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
5958
noop_pass = NoOpEliminationPass(vllm_config)
6059
fusion_pass = FusionPass.instance(vllm_config)
6160
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

tests/compile/test_fusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
FusionPass, QuantKey)
1010
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
1111
from vllm.compilation.noop_elimination import NoOpEliminationPass
12-
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
12+
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
13+
VllmConfig)
1314
from vllm.model_executor.layers.layernorm import RMSNorm
1415
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
1516
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
7879
vllm_config = VllmConfig(compilation_config=CompilationConfig(
7980
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
8081
vllm_config.compilation_config.pass_config = \
81-
CompilationConfig.PassConfig(enable_fusion=True,
82-
enable_noop=True)
82+
PassConfig(enable_fusion=True, enable_noop=True)
8383
with vllm.config.set_current_vllm_config(vllm_config):
8484
# Reshape pass is needed for the fusion pass to work
8585
noop_pass = NoOpEliminationPass(vllm_config)

tests/compile/test_sequence_parallelism.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
find_specified_fn_maybe, is_func)
1111
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
1212
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
13-
VllmConfig)
13+
PassConfig, VllmConfig)
1414
from vllm.distributed import tensor_model_parallel_all_reduce
1515
from vllm.distributed.parallel_state import (init_distributed_environment,
1616
initialize_model_parallel)
@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
126126

127127
# configure vllm config for SequenceParallelismPass
128128
vllm_config = VllmConfig()
129-
vllm_config.compilation_config = CompilationConfig(
130-
pass_config=CompilationConfig.PassConfig(
131-
enable_sequence_parallelism=True, ), )
129+
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
130+
enable_sequence_parallelism=True))
132131
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
133132

134133
# this is a fake model name to construct the model config

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from vllm._custom_ops import scaled_fp8_quant
77
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
88
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
9-
from vllm.config import CompilationConfig, VllmConfig
9+
from vllm.config import CompilationConfig, PassConfig, VllmConfig
1010
from vllm.model_executor.layers.activation import SiluAndMul
1111

1212
from .backend import TestBackend
@@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
3636
# Reshape pass is needed for the fusion pass to work
3737
config = VllmConfig()
3838
config.compilation_config = CompilationConfig(
39-
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
40-
enable_reshape=True))
39+
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
4140
fusion_pass = ActivationQuantFusionPass(config)
4241

4342
backend = TestBackend(fusion_pass)

tests/distributed/test_sequence_parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _compare_sp(
206206
'compile_sizes': [4, 8],
207207
'splitting_ops': [],
208208
'pass_config': {
209-
'enable_sequence_parallism': sp_enabled,
209+
'enable_sequence_parallelism': sp_enabled,
210210
'enable_noop': True,
211211
'enable_fusion': True,
212212
},
@@ -223,7 +223,7 @@ def _compare_sp(
223223
"--distributed-executor-backend",
224224
distributed_backend,
225225
"--compilation_config",
226-
str(compilation_config),
226+
json.dumps(compilation_config),
227227
]
228228

229229
tp_env = {

tests/engine/test_arg_utils.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,18 @@
88

99
import pytest
1010

11-
from vllm.config import config
11+
from vllm.config import CompilationConfig, config
1212
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
1313
get_type, is_not_builtin, is_type,
1414
literal_to_kwargs, nullable_kvs,
15-
optional_type)
15+
optional_type, parse_type)
1616
from vllm.utils import FlexibleArgumentParser
1717

1818

1919
@pytest.mark.parametrize(("type", "value", "expected"), [
2020
(int, "42", 42),
21-
(int, "None", None),
2221
(float, "3.14", 3.14),
23-
(float, "None", None),
2422
(str, "Hello World!", "Hello World!"),
25-
(str, "None", None),
2623
(json.loads, '{"foo":1,"bar":2}', {
2724
"foo": 1,
2825
"bar": 2
@@ -31,15 +28,20 @@
3128
"foo": 1,
3229
"bar": 2
3330
}),
34-
(json.loads, "None", None),
3531
])
36-
def test_optional_type(type, value, expected):
37-
optional_type_func = optional_type(type)
32+
def test_parse_type(type, value, expected):
33+
parse_type_func = parse_type(type)
3834
context = nullcontext()
3935
if value == "foo=1,bar=2":
4036
context = pytest.warns(DeprecationWarning)
4137
with context:
42-
assert optional_type_func(value) == expected
38+
assert parse_type_func(value) == expected
39+
40+
41+
def test_optional_type():
42+
optional_type_func = optional_type(int)
43+
assert optional_type_func("None") is None
44+
assert optional_type_func("42") == 42
4345

4446

4547
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
8991

9092
@config
9193
@dataclass
92-
class DummyConfigClass:
94+
class NestedConfig:
95+
field: int = 1
96+
"""field"""
97+
98+
99+
@config
100+
@dataclass
101+
class FromCliConfig1:
102+
field: int = 1
103+
"""field"""
104+
105+
@classmethod
106+
def from_cli(cls, cli_value: str):
107+
inst = cls(**json.loads(cli_value))
108+
inst.field += 1
109+
return inst
110+
111+
112+
@config
113+
@dataclass
114+
class FromCliConfig2:
115+
field: int = 1
116+
"""field"""
117+
118+
@classmethod
119+
def from_cli(cls, cli_value: str):
120+
inst = cls(**json.loads(cli_value))
121+
inst.field += 2
122+
return inst
123+
124+
125+
@config
126+
@dataclass
127+
class DummyConfig:
93128
regular_bool: bool = True
94129
"""Regular bool with default True"""
95130
optional_bool: Optional[bool] = None
@@ -108,18 +143,24 @@ class DummyConfigClass:
108143
"""Literal of literals with default 1"""
109144
json_tip: dict = field(default_factory=dict)
110145
"""Dict which will be JSON in CLI"""
146+
nested_config: NestedConfig = field(default_factory=NestedConfig)
147+
"""Nested config"""
148+
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
149+
"""Config with from_cli method"""
150+
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
151+
"""Different config with from_cli method"""
111152

112153

113154
@pytest.mark.parametrize(("type_hint", "expected"), [
114155
(int, False),
115-
(DummyConfigClass, True),
156+
(DummyConfig, True),
116157
])
117158
def test_is_not_builtin(type_hint, expected):
118159
assert is_not_builtin(type_hint) == expected
119160

120161

121162
def test_get_kwargs():
122-
kwargs = get_kwargs(DummyConfigClass)
163+
kwargs = get_kwargs(DummyConfig)
123164
print(kwargs)
124165

125166
# bools should not have their type set
@@ -142,6 +183,11 @@ def test_get_kwargs():
142183
# dict should have json tip in help
143184
json_tip = "\n\nShould be a valid JSON string."
144185
assert kwargs["json_tip"]["help"].endswith(json_tip)
186+
# nested config should construct the nested config
187+
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
188+
# from_cli configs should be constructed with the correct method
189+
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
190+
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
145191

146192

147193
@pytest.mark.parametrize(("arg", "expected"), [
@@ -177,7 +223,7 @@ def test_compilation_config():
177223

178224
# default value
179225
args = parser.parse_args([])
180-
assert args.compilation_config is None
226+
assert args.compilation_config == CompilationConfig()
181227

182228
# set to O3
183229
args = parser.parse_args(["-O3"])
@@ -194,15 +240,15 @@ def test_compilation_config():
194240
# set to string form of a dict
195241
args = parser.parse_args([
196242
"--compilation-config",
197-
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
243+
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
198244
])
199245
assert (args.compilation_config.level == 3 and
200246
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
201247

202248
# set to string form of a dict
203249
args = parser.parse_args([
204250
"--compilation-config="
205-
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
251+
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
206252
])
207253
assert (args.compilation_config.level == 3 and
208254
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])

vllm/compilation/vllm_inductor_pass.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
from vllm.config import CompilationConfig, VllmConfig
7+
from vllm.config import PassConfig, VllmConfig
88
# yapf: disable
99
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
1010
from vllm.distributed import (
@@ -56,10 +56,7 @@ def end_and_log(self):
5656

5757
class PrinterInductorPass(VllmInductorPass):
5858

59-
def __init__(self,
60-
name: str,
61-
config: CompilationConfig.PassConfig,
62-
always=False):
59+
def __init__(self, name: str, config: PassConfig, always=False):
6360
super().__init__(config)
6461
self.name = name
6562
self.always = always

0 commit comments

Comments
 (0)