Skip to content

Commit bd849d6

Browse files
committed
Rename compile_cond -> enable_if
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 9f79f23 commit bd849d6

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

tests/compile/test_decorator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ class C(B):
127127

128128
# Only enable torch.compile if
129129
# vllm_config.cache_config.kv_sharing_fast_prefill=True
130-
@support_torch_compile(compile_cond=lambda vllm_config: vllm_config.
131-
cache_config.kv_sharing_fast_prefill)
130+
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
131+
kv_sharing_fast_prefill)
132132
class B(nn.Module):
133133

134134
def __init__(self,
@@ -149,7 +149,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
149149

150150
# Only enable torch.compile if
151151
# vllm_config.cache_config.kv_sharing_fast_prefill=False
152-
@support_torch_compile(compile_cond=lambda vllm_config: not vllm_config.
152+
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
153153
cache_config.kv_sharing_fast_prefill)
154154
class A(nn.Module):
155155

@@ -171,7 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
171171
return x
172172

173173

174-
def test_support_torch_compile_cond():
174+
def test_conditional_compile_enable_if():
175175
vllm_config = VllmConfig(cache_config=CacheConfig(
176176
kv_sharing_fast_prefill=True, ),
177177
compilation_config=CompilationConfig(
@@ -184,8 +184,8 @@ def test_support_torch_compile_cond():
184184
with set_current_vllm_config(vllm_config):
185185
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
186186

187-
# A has support_torch_compile but compile_cond is not satisified
188-
# compile_cond will be satisified for B, so we expect mod1 and mod2
187+
# A has support_torch_compile but enable_if fn returns False
188+
# enable_if will be True for B, so we expect mod1 and mod2
189189
# to be compiled
190190
with compilation_counter.expect(
191191
num_graphs_seen=2,

vllm/compilation/decorators.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _should_ignore_torch_compile(cls) -> bool:
5555
@overload
5656
def support_torch_compile(
5757
*,
58-
compile_cond: Optional[Callable[[VllmConfig], bool]] = None,
58+
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
5959
) -> Callable[[_T], _T]:
6060
...
6161

@@ -77,7 +77,7 @@ def support_torch_compile(
7777
cls: Optional[_T] = None,
7878
*,
7979
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
80-
compile_cond: Optional[Callable[[VllmConfig], bool]] = None,
80+
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
8181
) -> Union[Callable[[_T], _T], _T]:
8282
"""
8383
A decorator to add support for compiling the forward method of a class.
@@ -128,7 +128,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
128128
the lifetime of the model, otherwise, it cannot be captured as a single
129129
computation graph.
130130
131-
`compile_cond` is a function that takes a `VllmConfig` object as input and
131+
`enable_if` is a function that takes a `VllmConfig` object as input and
132132
returns a boolean value indicating whether to compile the model or not.
133133
This is useful if you want to compile the model only when certain
134134
conditions are met.
@@ -164,7 +164,7 @@ def cls_decorator_helper(cls: _T) -> _T:
164164
raise ValueError(
165165
f"Argument {k} not found in the forward method of {cls}")
166166
return _support_torch_compile(cls, inferred_dynamic_arg_dims,
167-
compile_cond)
167+
enable_if)
168168

169169
if cls is not None:
170170
# use `support_torch_compile` as a decorator without arguments
@@ -177,7 +177,7 @@ def cls_decorator_helper(cls: _T) -> _T:
177177
def _support_torch_compile(
178178
cls: _T,
179179
dynamic_arg_dims: dict[str, Union[int, list[int]]],
180-
compile_cond: Optional[Callable[[VllmConfig], bool]] = None,
180+
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
181181
) -> _T:
182182
"""
183183
A decorator to add support for compiling the forward method of a class.
@@ -198,15 +198,14 @@ def _support_torch_compile(
198198
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
199199
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
200200
self.vllm_config = vllm_config
201-
compile_cond_satisfied = compile_cond is None or compile_cond(
202-
vllm_config)
201+
enable_compile = enable_if is None or enable_if(vllm_config)
203202
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
204203
# will handle the compilation, so we don't need to do anything here.
205204
self.do_not_compile = \
206205
vllm_config.compilation_config.level in [
207206
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
208207
] or not supports_dynamo() or _should_ignore_torch_compile(
209-
self.__class__) or not compile_cond_satisfied
208+
self.__class__) or not enable_compile
210209
if self.do_not_compile:
211210
return
212211

0 commit comments

Comments
 (0)