Skip to content

Commit 4b3e447

Browse files
authored
Optimize configuration access with LRU cache in custom ops (#22204)
Signed-off-by: zitian zhao <[email protected]>
1 parent bd3db7f commit 4b3e447

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

vllm/config.py

Lines changed: 9 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -15,7 +15,7 @@
1515
from contextlib import contextmanager
1616
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
1717
replace)
18-
from functools import cached_property
18+
from functools import cached_property, lru_cache
1919
from importlib.util import find_spec
2020
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
2121
Protocol, TypeVar, Union, cast, get_args)
@@ -5123,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig,
51235123
finally:
51245124
_current_vllm_config = old_vllm_config
51255125
_current_prefix = old_prefix
5126+
# Clear the compilation config cache when context changes
5127+
get_cached_compilation_config.cache_clear()
5128+
5129+
5130+
@lru_cache(maxsize=1)
def get_cached_compilation_config():
    """Return the compilation config of the currently active vLLM config.

    Memoized with a single-slot LRU cache so repeated lookups (e.g. from
    custom-op dispatch) avoid calling get_current_vllm_config() each time.
    The cache is invalidated via cache_clear() when the active config
    context changes in set_current_vllm_config().
    """
    vllm_config = get_current_vllm_config()
    return vllm_config.compilation_config
51265134

51275135

51285136
def get_current_vllm_config() -> VllmConfig:

vllm/model_executor/custom_op.py

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -5,7 +5,7 @@
55

66
import torch.nn as nn
77

8-
from vllm.config import get_current_vllm_config
8+
from vllm.config import get_cached_compilation_config
99
from vllm.logger import init_logger
1010
from vllm.platforms import current_platform
1111

@@ -86,7 +86,7 @@ def forward_oot(self, *args, **kwargs):
8686
def dispatch_forward(self):
8787
# NOTE(woosuk): Here we assume that vLLM was built for only one
8888
# specific backend. Currently, we do not support dynamic dispatching.
89-
compilation_config = get_current_vllm_config().compilation_config
89+
compilation_config = get_cached_compilation_config()
9090
enabled = self.enabled()
9191
if enabled:
9292
compilation_config.enabled_custom_ops.update([self.__class__.name])
@@ -115,7 +115,7 @@ def dispatch_forward(self):
115115
@classmethod
116116
def enabled(cls) -> bool:
117117
# if no name, then it was not registered
118-
compilation_config = get_current_vllm_config().compilation_config
118+
compilation_config = get_cached_compilation_config()
119119
custom_ops = compilation_config.custom_ops
120120
if not hasattr(cls, "name"):
121121
logger.warning_once(
@@ -138,7 +138,7 @@ def default_on() -> bool:
138138
Specifying 'all' or 'none' in custom_op takes precedence.
139139
"""
140140
from vllm.config import CompilationLevel
141-
compilation_config = get_current_vllm_config().compilation_config
141+
compilation_config = get_cached_compilation_config()
142142
default_on = (compilation_config.level < CompilationLevel.PIECEWISE
143143
or not compilation_config.use_inductor)
144144
count_none = compilation_config.custom_ops.count("none")

0 commit comments

Comments (0)