
Commit d92ae04

Support conditional torch.compile

Signed-off-by: Yong Hoon Shin <[email protected]>

1 parent 1d20c34 commit d92ae04

File tree

3 files changed: +252, -90 lines changed

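The new tests/compile/test_decorator.py below exercises a compile_cond argument on the support_torch_compile decorator: a callable that receives the VllmConfig and returns whether the decorated module should be compiled. A minimal sketch of the pattern, assuming only the decorator usage shown in the diff (MyLayer and its condition are illustrative, not part of the commit):

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


# Hypothetical module: compiled only when fast-prefill KV sharing is enabled
# in the cache config; otherwise it runs eagerly, as if undecorated.
@support_torch_compile(compile_cond=lambda vllm_config: vllm_config.
                       cache_config.kv_sharing_fast_prefill)
class MyLayer(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + x

When the condition evaluates to False, the class behaves as if it had not been decorated, which test_support_torch_compile_cond verifies by flipping kv_sharing_fast_prefill in the cache config.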

tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 0 additions & 88 deletions
@@ -14,7 +14,6 @@
                                          support_torch_compile)
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                          set_current_vllm_config)
-from vllm.envs import VLLM_USE_V1
 from vllm.forward_context import set_forward_context
 from vllm.utils import direct_register_custom_op

@@ -164,93 +163,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def test_ignore_torch_compile_decorator():
-    assert VLLM_USE_V1
-
-    # piecewise
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=True,
-        splitting_ops=["silly.attention"],
-        cudagraph_capture_sizes=[1, 2],
-    ))
-
-    @support_torch_compile
-    class A(nn.Module):
-
-        def __init__(self,
-                     *,
-                     vllm_config: VllmConfig,
-                     prefix: str = '',
-                     **kwargs) -> None:
-            super().__init__()
-
-        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            x = x + x
-            attn_output = torch.empty_like(x)
-            torch.ops.silly.attention(x, x, x, attn_output)
-            x = attn_output
-            x = x * 3
-            return x
-
-    @ignore_torch_compile
-    class B(A):
-        ...
-
-    @support_torch_compile
-    class C(B):
-        ...
-
-    with set_current_vllm_config(vllm_config):
-        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # A has support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        # first run is for compile
-        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        # run cudagraph captured sizes
-        mod_A(torch.randn(2, MLP_SIZE).cuda())
-        mod_A(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # B's ignore_torch_compile should override A's support_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=0,
-            num_piecewise_graphs_seen=0,
-            num_piecewise_capturable_graphs_seen=0,
-            num_backend_compilations=0,
-            num_cudagraph_captured=0,
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_B(torch.randn(2, MLP_SIZE).cuda())
-        mod_B(torch.randn(1, MLP_SIZE).cuda())
-
-    with set_current_vllm_config(vllm_config):
-        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
-
-    # C's support_torch_compile should override B's ignore_torch_compile
-    with compilation_counter.expect(
-            num_graphs_seen=1,
-            num_piecewise_graphs_seen=3,
-            num_piecewise_capturable_graphs_seen=2,
-            num_backend_compilations=2,
-            num_cudagraph_captured=4,
-            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
-        mod_C(torch.randn(2, MLP_SIZE).cuda())
-        mod_C(torch.randn(1, MLP_SIZE).cuda())
-
-
 @torch.inference_mode
 def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
     with set_forward_context({}, vllm_config=vllm_config):

tests/compile/test_decorator.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                         VllmConfig, set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa

BATCH_SIZE = 32
MLP_SIZE = 128


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    out.copy_(q)
    out += k
    out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)


def test_ignore_torch_compile_decorator():
    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))

    @support_torch_compile
    class A(nn.Module):

        def __init__(self,
                     *,
                     vllm_config: VllmConfig,
                     prefix: str = '',
                     **kwargs) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + x
            attn_output = torch.empty_like(x)
            torch.ops.silly.attention(x, x, x, attn_output)
            x = attn_output
            x = x * 3
            return x

    @ignore_torch_compile
    class B(A):
        ...

    @support_torch_compile
    class C(B):
        ...

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        # first run is for compile
        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        # run cudagraph captured sizes
        mod_A(torch.randn(2, MLP_SIZE).cuda())
        mod_A(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

    # B's ignore_torch_compile should override A's support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_B(torch.randn(2, MLP_SIZE).cuda())
        mod_B(torch.randn(1, MLP_SIZE).cuda())

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

    # C's support_torch_compile should override B's ignore_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        mod_C(torch.randn(2, MLP_SIZE).cuda())
        mod_C(torch.randn(1, MLP_SIZE).cuda())


# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(compile_cond=lambda vllm_config: vllm_config.
                       cache_config.kv_sharing_fast_prefill)
class B(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = x + x
        return x


# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(compile_cond=lambda vllm_config: not vllm_config.
                       cache_config.kv_sharing_fast_prefill)
class A(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mod1(x)
        attn_output = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, attn_output)
        x = attn_output
        x = self.mod2(x)
        return x


def test_support_torch_compile_cond():
    vllm_config = VllmConfig(cache_config=CacheConfig(
        kv_sharing_fast_prefill=True, ),
                             compilation_config=CompilationConfig(
                                 level=CompilationLevel.PIECEWISE,
                                 use_cudagraph=True,
                                 splitting_ops=["silly.attention"],
                                 cudagraph_capture_sizes=[1, 2],
                             ))

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile but its compile_cond is not satisfied.
    # compile_cond is satisfied for B, so we expect mod1 and mod2
    # to be compiled
    with compilation_counter.expect(
            num_graphs_seen=2,
            num_piecewise_graphs_seen=6,
            # 3 piecewise graphs per instance of B()
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        # first run is for compile
        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        # run cudagraph captured sizes
        mod_A(torch.randn(2, MLP_SIZE).cuda())
        mod_A(torch.randn(1, MLP_SIZE).cuda())

    # Set kv_sharing_fast_prefill=False,
    # which will cause A to be compiled and B to not be compiled
    vllm_config = VllmConfig(cache_config=CacheConfig(
        kv_sharing_fast_prefill=False, ),
                             compilation_config=CompilationConfig(
                                 level=CompilationLevel.PIECEWISE,
                                 use_cudagraph=True,
                                 splitting_ops=["silly.attention"],
                                 cudagraph_capture_sizes=[1, 2],
                             ))

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=7,
            # 3 attn ops and 4 non-attn ops
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ), set_forward_context({}, vllm_config=vllm_config):
        # first run is for compile
        mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
        # run cudagraph captured sizes
        mod_A(torch.randn(2, MLP_SIZE).cuda())
        mod_A(torch.randn(1, MLP_SIZE).cuda())
