8
8
9
9
import pytest
10
10
11
- from vllm .config import config
11
+ from vllm .config import CompilationConfig , config
12
12
from vllm .engine .arg_utils import (EngineArgs , contains_type , get_kwargs ,
13
13
get_type , is_not_builtin , is_type ,
14
14
literal_to_kwargs , nullable_kvs ,
15
- optional_type )
15
+ optional_type , parse_type )
16
16
from vllm .utils import FlexibleArgumentParser
17
17
18
18
19
19
@pytest .mark .parametrize (("type" , "value" , "expected" ), [
20
20
(int , "42" , 42 ),
21
- (int , "None" , None ),
22
21
(float , "3.14" , 3.14 ),
23
- (float , "None" , None ),
24
22
(str , "Hello World!" , "Hello World!" ),
25
- (str , "None" , None ),
26
23
(json .loads , '{"foo":1,"bar":2}' , {
27
24
"foo" : 1 ,
28
25
"bar" : 2
31
28
"foo" : 1 ,
32
29
"bar" : 2
33
30
}),
34
- (json .loads , "None" , None ),
35
31
])
36
- def test_optional_type (type , value , expected ):
37
- optional_type_func = optional_type (type )
32
+ def test_parse_type (type , value , expected ):
33
+ parse_type_func = parse_type (type )
38
34
context = nullcontext ()
39
35
if value == "foo=1,bar=2" :
40
36
context = pytest .warns (DeprecationWarning )
41
37
with context :
42
- assert optional_type_func (value ) == expected
38
+ assert parse_type_func (value ) == expected
39
+
40
+
41
+ def test_optional_type ():
42
+ optional_type_func = optional_type (int )
43
+ assert optional_type_func ("None" ) is None
44
+ assert optional_type_func ("42" ) == 42
43
45
44
46
45
47
@pytest .mark .parametrize (("type_hint" , "type" , "expected" ), [
@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
89
91
90
92
@config
91
93
@dataclass
92
- class DummyConfigClass :
94
+ class NestedConfig :
95
+ field : int = 1
96
+ """field"""
97
+
98
+
99
+ @config
100
+ @dataclass
101
+ class FromCliConfig1 :
102
+ field : int = 1
103
+ """field"""
104
+
105
+ @classmethod
106
+ def from_cli (cls , cli_value : str ):
107
+ inst = cls (** json .loads (cli_value ))
108
+ inst .field += 1
109
+ return inst
110
+
111
+
112
+ @config
113
+ @dataclass
114
+ class FromCliConfig2 :
115
+ field : int = 1
116
+ """field"""
117
+
118
+ @classmethod
119
+ def from_cli (cls , cli_value : str ):
120
+ inst = cls (** json .loads (cli_value ))
121
+ inst .field += 2
122
+ return inst
123
+
124
+
125
+ @config
126
+ @dataclass
127
+ class DummyConfig :
93
128
regular_bool : bool = True
94
129
"""Regular bool with default True"""
95
130
optional_bool : Optional [bool ] = None
@@ -108,18 +143,24 @@ class DummyConfigClass:
108
143
"""Literal of literals with default 1"""
109
144
json_tip : dict = field (default_factory = dict )
110
145
"""Dict which will be JSON in CLI"""
146
+ nested_config : NestedConfig = field (default_factory = NestedConfig )
147
+ """Nested config"""
148
+ from_cli_config1 : FromCliConfig1 = field (default_factory = FromCliConfig1 )
149
+ """Config with from_cli method"""
150
+ from_cli_config2 : FromCliConfig2 = field (default_factory = FromCliConfig2 )
151
+ """Different config with from_cli method"""
111
152
112
153
113
154
@pytest .mark .parametrize (("type_hint" , "expected" ), [
114
155
(int , False ),
115
- (DummyConfigClass , True ),
156
+ (DummyConfig , True ),
116
157
])
117
158
def test_is_not_builtin (type_hint , expected ):
118
159
assert is_not_builtin (type_hint ) == expected
119
160
120
161
121
162
def test_get_kwargs ():
122
- kwargs = get_kwargs (DummyConfigClass )
163
+ kwargs = get_kwargs (DummyConfig )
123
164
print (kwargs )
124
165
125
166
# bools should not have their type set
@@ -142,6 +183,11 @@ def test_get_kwargs():
142
183
# dict should have json tip in help
143
184
json_tip = "\n \n Should be a valid JSON string."
144
185
assert kwargs ["json_tip" ]["help" ].endswith (json_tip )
186
+ # nested config should should construct the nested config
187
+ assert kwargs ["nested_config" ]["type" ]('{"field": 2}' ) == NestedConfig (2 )
188
+ # from_cli configs should be constructed with the correct method
189
+ assert kwargs ["from_cli_config1" ]["type" ]('{"field": 2}' ).field == 3
190
+ assert kwargs ["from_cli_config2" ]["type" ]('{"field": 2}' ).field == 4
145
191
146
192
147
193
@pytest .mark .parametrize (("arg" , "expected" ), [
@@ -177,7 +223,7 @@ def test_compilation_config():
177
223
178
224
# default value
179
225
args = parser .parse_args ([])
180
- assert args .compilation_config is None
226
+ assert args .compilation_config == CompilationConfig ()
181
227
182
228
# set to O3
183
229
args = parser .parse_args (["-O3" ])
@@ -194,15 +240,15 @@ def test_compilation_config():
194
240
# set to string form of a dict
195
241
args = parser .parse_args ([
196
242
"--compilation-config" ,
197
- "{' level' : 3, ' cudagraph_capture_sizes' : [1, 2, 4, 8]}" ,
243
+ '{" level" : 3, " cudagraph_capture_sizes" : [1, 2, 4, 8]}' ,
198
244
])
199
245
assert (args .compilation_config .level == 3 and
200
246
args .compilation_config .cudagraph_capture_sizes == [1 , 2 , 4 , 8 ])
201
247
202
248
# set to string form of a dict
203
249
args = parser .parse_args ([
204
250
"--compilation-config="
205
- "{' level' : 3, ' cudagraph_capture_sizes' : [1, 2, 4, 8]}" ,
251
+ '{" level" : 3, " cudagraph_capture_sizes" : [1, 2, 4, 8]}' ,
206
252
])
207
253
assert (args .compilation_config .level == 3 and
208
254
args .compilation_config .cudagraph_capture_sizes == [1 , 2 , 4 , 8 ])
0 commit comments