
Commit 4ac7713

Add test case for compiling multiple graphs (#21044)
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 8560a5b commit 4ac7713

File tree

3 files changed: +390 -1 lines changed
Lines changed: 350 additions & 0 deletions
@@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa

BATCH_SIZE = 32
MLP_SIZE = 128
HIDDEN_SIZE = 1024
RANDOM_SEED = 0


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)

@support_torch_compile
class ParentModel(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class Attention(nn.Module):

    def __init__(self, mlp_size: int, hidden_size: int) -> None:
        super().__init__()
        self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
        self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False)
        self.rms_norm_weight = nn.Parameter(torch.ones(hidden_size))

        # Initialize to same weights for testing
        nn.init.xavier_normal_(
            self.pre_attn.weight.data,
            generator=torch.Generator().manual_seed(RANDOM_SEED),
            gain=0.001)
        nn.init.xavier_normal_(
            self.post_attn.weight.data,
            generator=torch.Generator().manual_seed(RANDOM_SEED),
            gain=0.001)

    def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
        x_f32 = x.float()
        return (x_f32 * torch.rsqrt(
            torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
                self.rms_norm_weight).to(x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre_attn(x)
        x = self.rms_norm_ref(x)
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = self.rms_norm_ref(x)
        x = self.post_attn(x)
        return x


@support_torch_compile
class CompiledAttention(nn.Module):

    def __init__(self,
                 *,
                 mlp_size: int,
                 hidden_size: int,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        self.attn = Attention(mlp_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(x)


@support_torch_compile
class CompiledAttentionTwo(CompiledAttention):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(x) + x

@ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel):

    def __init__(self,
                 *,
                 mlp_size: int,
                 hidden_size: int,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # The test will fail without set_model_tag here, with the error:
        # "ValueError: too many values to unpack (expected 3)"
        # This is because CompiledAttention and CompiledAttentionTwo
        # have different implementations but would share the same
        # torch.compile cache dir, since the default prefix is 'model_tag'
        with set_model_tag("attn_one"):
            self.attn_one = CompiledAttention(
                mlp_size=mlp_size,
                hidden_size=hidden_size,
                vllm_config=vllm_config,
                prefix=f"{prefix}.attn_one",
            )
        with set_model_tag("attn_two"):
            self.attn_two = CompiledAttentionTwo(
                mlp_size=mlp_size,
                hidden_size=hidden_size,
                vllm_config=vllm_config,
                prefix=f"{prefix}.attn_two",
            )

        self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz = x.shape[0]
        # CUDAGraph expects same tensor addresses for each run
        self.hidden_states[:bsz].copy_(x)
        x = self.attn_one(self.hidden_states[:bsz])
        self.hidden_states[:bsz].copy_(x)
        x = self.attn_two(self.hidden_states[:bsz])
        return x

def test_ignore_torch_compile_decorator():
    assert VLLM_USE_V1

    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))

    @support_torch_compile
    class A(nn.Module):

        def __init__(self,
                     *,
                     vllm_config: VllmConfig,
                     prefix: str = '',
                     **kwargs) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + x
            attn_output = torch.empty_like(x)
            torch.ops.silly.attention(x, x, x, attn_output)
            x = attn_output
            x = x * 3
            return x

    @ignore_torch_compile
    class B(A):
        ...

    @support_torch_compile
    class C(B):
        ...

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        # first run is for compile
        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        # run cudagraph captured sizes
        mod_A(torch.randn(2, MLP_SIZE).cuda())
        mod_A(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

    # B's ignore_torch_compile should override A's support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_B(torch.randn(2, MLP_SIZE).cuda())
        mod_B(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

    # C's support_torch_compile should override B's ignore_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_C(torch.randn(2, MLP_SIZE).cuda())
        mod_C(torch.randn(1, MLP_SIZE).cuda())

@torch.inference_mode
def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
    with set_forward_context({}, vllm_config=vllm_config):
        # First run is for compile
        model(inputs)

        # Run CUDAGraph captured sizes
        model(inputs[:2])
        model(inputs[:1])

        output = model(inputs[:2])

    return output.cpu()

def test_multi_graph_piecewise_compile_outputs_equal():
    outputs = []

    # piecewise compile
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))

    with set_current_vllm_config(vllm_config):
        model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
                                         hidden_size=HIDDEN_SIZE,
                                         vllm_config=vllm_config,
                                         prefix='').eval().cuda()

    # Pre-allocate memory for CUDAGraph which expects
    # static tensor addresses
    inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

    with compilation_counter.expect(
            num_graphs_seen=2,  # two graphs for the model
            num_piecewise_graphs_seen=6,
            # attn_one and attn_two each have 3 piecewise graphs
            # (pre attn, silly_attention, post attn)
            num_piecewise_capturable_graphs_seen=4,
            # attn_one and attn_two each have pre attn and post attn, total=4
            num_backend_compilations=4,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(run_model(vllm_config, model, inputs))

    # no compile or cudagraph
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.NO_COMPILATION, ))

    with set_current_vllm_config(vllm_config):
        model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
                                         hidden_size=HIDDEN_SIZE,
                                         vllm_config=vllm_config,
                                         prefix='').eval().cuda()

    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ):
        outputs.append(run_model(vllm_config, model, inputs))

    # piecewise compile without CUDA graph
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=False,
        splitting_ops=["silly.attention"],
    ))

    with set_current_vllm_config(vllm_config):
        model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
                                         hidden_size=HIDDEN_SIZE,
                                         vllm_config=vllm_config,
                                         prefix='').eval().cuda()

    with compilation_counter.expect(
            num_graphs_seen=2,
            num_piecewise_graphs_seen=6,
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=0,  # no cudagraph captured
    ):
        outputs.append(run_model(vllm_config, model, inputs))

    # In general we don't expect outputs with and without inductor
    # to be bitwise equivalent
    assert torch.allclose(outputs[0], outputs[1])

    # Expect bitwise equivalence using inductor w/ and w/o cudagraph
    assert torch.equal(outputs[0], outputs[2])
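
For reference, the counter expectations used throughout these tests follow the arithmetic spelled out in the inline comments: each splitting op cuts a compiled graph into one more capturable piece, and a CUDA graph is captured once per capturable piece per capture size. Below is a small sketch of that relationship; it is illustrative only, and the expected_counters helper is not part of the committed test.

# Sketch (assumed, not committed code): derive the expected
# compilation_counter values from the relationships stated in the
# test's inline comments.
def expected_counters(num_compiled_graphs: int, splitting_ops_per_graph: int,
                      num_cudagraph_sizes: int, use_cudagraph: bool) -> dict:
    # One compiled graph with N splitting ops yields N + 1 capturable
    # pieces and 2N + 1 piecewise graphs (the ops themselves are the
    # non-capturable pieces in between).
    piecewise = 2 * splitting_ops_per_graph + 1
    capturable = splitting_ops_per_graph + 1
    total_capturable = num_compiled_graphs * capturable
    return {
        "num_graphs_seen": num_compiled_graphs,
        "num_piecewise_graphs_seen": num_compiled_graphs * piecewise,
        "num_piecewise_capturable_graphs_seen": total_capturable,
        "num_backend_compilations": total_capturable,
        "num_cudagraph_captured":
        total_capturable * num_cudagraph_sizes if use_cudagraph else 0,
    }


# SimpleModelWithTwoGraphs: 2 compiled submodules, 1 splitting op each,
# cudagraph_capture_sizes=[1, 2] -> 8 CUDA graphs captured.
assert expected_counters(2, 1, 2, True)["num_cudagraph_captured"] == 8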

vllm/compilation/compiler_interface.py

Lines changed: 6 additions & 0 deletions
@@ -423,6 +423,12 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
         if is_torch_equal_or_newer("2.6"):
             stack.enter_context(
                 torch._inductor.config.patch(fx_graph_remote_cache=False))
+            # InductorAdaptor (unfortunately) requires AOTAutogradCache
+            # to be turned off to run. It will fail to acquire the hash_str
+            # and error if not.
+            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
+            stack.enter_context(
+                torch._functorch.config.patch(enable_autograd_cache=False))
             stack.enter_context(
                 torch._functorch.config.patch(
                     enable_remote_autograd_cache=False))
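
For context, the added patch joins the cache-disabling patches that are already stacked before the Inductor compile call. The following is a minimal standalone sketch of that pattern, illustrative only; it assumes a contextlib.ExitStack around the compile call as in the surrounding code, and the compile call itself is elided.

# Sketch of stacking the cache-disabling config patches
# (illustration only; mirrors the patches shown in the diff above).
import contextlib

import torch._functorch.config
import torch._inductor.config

with contextlib.ExitStack() as stack:
    # Remote FX graph cache off (the pre-existing torch >= 2.6 patch).
    stack.enter_context(
        torch._inductor.config.patch(fx_graph_remote_cache=False))
    # New in this commit: AOTAutogradCache must be off so that
    # InductorAdaptor can recover the hash_str.
    stack.enter_context(
        torch._functorch.config.patch(enable_autograd_cache=False))
    stack.enter_context(
        torch._functorch.config.patch(enable_remote_autograd_cache=False))
    # ... the Inductor compilation would run here ...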
