
Commit 30df952

sarckk authored and ProExpertProg committed
Fix compile tests
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 6b70132 commit 30df952

File tree

2 files changed: +83 additions, -41 deletions

tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 36 additions & 13 deletions
@@ -12,9 +12,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-from vllm.forward_context import set_forward_context
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -164,16 +164,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 @torch.inference_mode
-def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
+def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
+              cudagraph_runtime_mode: CUDAGraphMode):
     with set_forward_context({}, vllm_config=vllm_config):
-        # First run is for compile
+        # warmup for the model with cudagraph_mode NONE
         model(inputs)
 
-        # Run CUDAGraph captured sizes
-        model(inputs[:2])
-        model(inputs[:1])
-
-        output = model(inputs[:2])
+        # simulate cudagraphs capturing
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            model(inputs[:2])
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=1, )):
+            model(inputs[:1])
+
+        # simulate cudagraphs replay
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            output = model(inputs[:2])
 
     output = output.cpu()
     return output.cpu()
@@ -189,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         splitting_ops=["silly.attention"],
         cudagraph_capture_sizes=[1, 2],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -211,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_cudagraph_captured=8,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # no compile or cudagraph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.NO_COMPILATION, ))
+    cudagraph_runtime_mode = CUDAGraphMode.NONE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -230,14 +250,16 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=0,
             num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # piecewise compile without CUDA graph
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=False,
         splitting_ops=["silly.attention"],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
@@ -252,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             num_backend_compilations=4,
             num_cudagraph_captured=0,  # no cudagraph captured
     ):
-        outputs.append(run_model(vllm_config, model, inputs))
+        outputs.append(
+            run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
 
     # Generally don't expect outputs with and without inductor
     # to be bitwise equivalent
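
The rewritten run_model above splits the old single-context flow into three phases: a warmup pass with no cudagraph runtime mode, one simulated capture pass per cudagraph size, and a simulated replay pass. A minimal standalone sketch of that pattern, using only the imports this test already has (capture_then_replay is a hypothetical helper name, not part of vllm):

import torch
from torch import nn

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, set_forward_context


@torch.inference_mode
def capture_then_replay(model: nn.Module, inputs: torch.Tensor,
                        vllm_config: VllmConfig) -> torch.Tensor:
    runtime_mode = CUDAGraphMode.PIECEWISE
    with set_forward_context({}, vllm_config=vllm_config):
        # phase 1: warmup pass, no cudagraph runtime mode set
        model(inputs)
        # phase 2: simulated capture, one context per captured size;
        # BatchDescriptor.num_tokens must match the sliced batch length
        for num_tokens in (2, 1):
            with set_forward_context(
                    {},
                    vllm_config=vllm_config,
                    cudagraph_runtime_mode=runtime_mode,
                    batch_descriptor=BatchDescriptor(num_tokens=num_tokens)):
                model(inputs[:num_tokens])
        # phase 3: simulated replay, reusing the size-2 descriptor
        with set_forward_context(
                {},
                vllm_config=vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=BatchDescriptor(num_tokens=2)):
            return model(inputs[:2]).cpu()

The key design point the commit encodes: capture and replay must run under the same cudagraph_runtime_mode and an identical BatchDescriptor, while the warmup call deliberately runs outside any such context.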

tests/compile/test_decorator.py

Lines changed: 47 additions & 28 deletions
@@ -8,8 +8,8 @@
 from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
-                         VllmConfig, set_current_vllm_config)
-from vllm.forward_context import set_forward_context
+                         CUDAGraphMode, VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -40,6 +40,39 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 )
 
 
+@torch.inference_mode
+def run_model(vllm_config: VllmConfig, model: nn.Module,
+              cudagraph_runtime_mode: CUDAGraphMode):
+    with set_forward_context({}, vllm_config=vllm_config):
+        # warmup for the model with cudagraph_mode NONE
+        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
+
+        # simulate cudagraphs capturing
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            model(torch.randn(2, MLP_SIZE).cuda())
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=1, )):
+            model(torch.randn(1, MLP_SIZE).cuda())
+
+        # simulate cudagraphs replay
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            output = model(torch.randn(2, MLP_SIZE).cuda())
+
+    output = output.cpu()
+    return output.cpu()
+
+
 def test_ignore_torch_compile_decorator():
     # piecewise
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -48,6 +81,7 @@ def test_ignore_torch_compile_decorator():
         splitting_ops=["silly.attention"],
         cudagraph_capture_sizes=[1, 2],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     @support_torch_compile
     class A(nn.Module):
@@ -86,12 +120,8 @@ class C(B):
             num_backend_compilations=2,
             num_cudagraph_captured=4,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
     with set_current_vllm_config(vllm_config):
         mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -103,10 +133,8 @@ class C(B):
             num_piecewise_capturable_graphs_seen=0,
             num_backend_compilations=0,
             num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_B, cudagraph_runtime_mode)
 
     with set_current_vllm_config(vllm_config):
         mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -119,10 +147,8 @@ class C(B):
             num_backend_compilations=2,
             num_cudagraph_captured=4,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_C, cudagraph_runtime_mode)
 
 
 # Only enable torch.compile if
@@ -180,6 +206,7 @@ def test_conditional_compile_enable_if():
         splitting_ops=["silly.attention"],
         cudagraph_capture_sizes=[1, 2],
     ))
+    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
     with set_current_vllm_config(vllm_config):
         mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
@@ -195,12 +222,8 @@
             num_backend_compilations=4,
             num_cudagraph_captured=8,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
     # Set kv_sharing_fast_prefill=False
     # which will cause A to be compiled and B to not be compiled
@@ -224,9 +247,5 @@
             num_backend_compilations=4,
             num_cudagraph_captured=8,
             # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
+    ):
+        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
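
All three tests in this file now delegate to the shared run_model helper inside their counter-expectation blocks. Condensed from the mod_A hunk above, the resulting call pattern looks roughly like this (counts copied from the diff; passing only a subset of expect()'s counter kwargs is an assumption made here for brevity):

# condensed sketch of the pattern introduced above, not verbatim test code
with compilation_counter.expect(
        num_backend_compilations=2,
        num_cudagraph_captured=4,
        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
    # one call now covers warmup, simulated capture, and simulated replay
    run_model(vllm_config, mod_A, cudagraph_runtime_mode)

This replaces the previous pattern of stacking set_forward_context next to compilation_counter.expect and calling the module three times inline.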
