@@ -55,7 +55,7 @@ def _should_ignore_torch_compile(cls) -> bool:
55
55
@overload
56
56
def support_torch_compile (
57
57
* ,
58
- compile_cond : Optional [Callable [[VllmConfig ], bool ]] = None ,
58
+ enable_if : Optional [Callable [[VllmConfig ], bool ]] = None ,
59
59
) -> Callable [[_T ], _T ]:
60
60
...
61
61
@@ -77,7 +77,7 @@ def support_torch_compile(
77
77
cls : Optional [_T ] = None ,
78
78
* ,
79
79
dynamic_arg_dims : Optional [dict [str , Union [int , list [int ]]]] = None ,
80
- compile_cond : Optional [Callable [[VllmConfig ], bool ]] = None ,
80
+ enable_if : Optional [Callable [[VllmConfig ], bool ]] = None ,
81
81
) -> Union [Callable [[_T ], _T ], _T ]:
82
82
"""
83
83
A decorator to add support for compiling the forward method of a class.
@@ -128,7 +128,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
128
128
the lifetime of the model, otherwise, it cannot be captured as a single
129
129
computation graph.
130
130
131
- `compile_cond ` is a function that takes a `VllmConfig` object as input and
131
+ `enable_if ` is a function that takes a `VllmConfig` object as input and
132
132
returns a boolean value indicating whether to compile the model or not.
133
133
This is useful if you want to compile the model only when certain
134
134
conditions are met.
@@ -164,7 +164,7 @@ def cls_decorator_helper(cls: _T) -> _T:
164
164
raise ValueError (
165
165
f"Argument { k } not found in the forward method of { cls } " )
166
166
return _support_torch_compile (cls , inferred_dynamic_arg_dims ,
167
- compile_cond )
167
+ enable_if )
168
168
169
169
if cls is not None :
170
170
# use `support_torch_compile` as a decorator without arguments
@@ -177,7 +177,7 @@ def cls_decorator_helper(cls: _T) -> _T:
177
177
def _support_torch_compile (
178
178
cls : _T ,
179
179
dynamic_arg_dims : dict [str , Union [int , list [int ]]],
180
- compile_cond : Optional [Callable [[VllmConfig ], bool ]] = None ,
180
+ enable_if : Optional [Callable [[VllmConfig ], bool ]] = None ,
181
181
) -> _T :
182
182
"""
183
183
A decorator to add support for compiling the forward method of a class.
@@ -198,15 +198,14 @@ def _support_torch_compile(
198
198
def __init__ (self , * , vllm_config : VllmConfig , prefix : str = '' , ** kwargs ):
199
199
old_init (self , vllm_config = vllm_config , prefix = prefix , ** kwargs )
200
200
self .vllm_config = vllm_config
201
- compile_cond_satisfied = compile_cond is None or compile_cond (
202
- vllm_config )
201
+ enable_compile = enable_if is None or enable_if (vllm_config )
203
202
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
204
203
# will handle the compilation, so we don't need to do anything here.
205
204
self .do_not_compile = \
206
205
vllm_config .compilation_config .level in [
207
206
CompilationLevel .NO_COMPILATION , CompilationLevel .DYNAMO_AS_IS
208
207
] or not supports_dynamo () or _should_ignore_torch_compile (
209
- self .__class__ ) or not compile_cond_satisfied
208
+ self .__class__ ) or not enable_compile
210
209
if self .do_not_compile :
211
210
return
212
211
0 commit comments