Skip to content

Commit eaec4b9

Browse files
tdoublep and cyang49 authored
[Bugfix] Add custom Triton cache manager to resolve MoE MP issue (#6140)
Signed-off-by: Thomas Parnell <[email protected]> Co-authored-by: Chih-Chieh-Yang <[email protected]>
1 parent a63a4c6 commit eaec4b9

File tree

3 files changed

+64
-0
lines changed

3 files changed

+64
-0
lines changed

vllm/executor/multiproc_gpu_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ResultHandler, WorkerMonitor)
1010
from vllm.logger import init_logger
1111
from vllm.sequence import ExecuteModelRequest, SamplerOutput
12+
from vllm.triton_utils import maybe_set_triton_cache_manager
1213
from vllm.utils import (cuda_device_count_stateless,
1314
error_on_invalid_device_count_status,
1415
get_distributed_init_method, get_open_port,
@@ -42,6 +43,10 @@ def _init_executor(self) -> None:
4243
if "OMP_NUM_THREADS" not in os.environ:
4344
os.environ["OMP_NUM_THREADS"] = "1"
4445

46+
# workaround for https://github.com/vllm-project/vllm/issues/6103
47+
if world_size > 1:
48+
maybe_set_triton_cache_manager()
49+
4550
assert world_size <= cuda_device_count_stateless(), (
4651
"please set tensor_parallel_size to less than max local gpu count")
4752

vllm/triton_utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from vllm.triton_utils.custom_cache_manager import (
2+
maybe_set_triton_cache_manager)
3+
4+
__all__ = [
5+
"maybe_set_triton_cache_manager",
6+
]
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import os
2+
3+
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
4+
default_dump_dir, default_override_dir)
5+
6+
from vllm.logger import init_logger
7+
8+
logger = init_logger(__name__)
9+
10+
11+
def maybe_set_triton_cache_manager() -> None:
    """Point Triton at vLLM's per-process cache manager via env var.

    Sets ``TRITON_CACHE_MANAGER`` to :class:`CustomCacheManager` only when
    the user has not already configured one, so an explicit user choice
    always takes precedence. Workaround for cache-dir collisions when
    running tensor-parallel with the multiprocessing backend
    (https://github.com/vllm-project/vllm/issues/6103).
    """
    # Fixed typo: local was previously named `cache_manger`.
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager is None:
        manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
        logger.info("Setting Triton cache manager to: %s", manager)
        os.environ["TRITON_CACHE_MANAGER"] = manager
19+
20+
21+
class CustomCacheManager(FileCacheManager):
    """Drop-in replacement for Triton's file cache manager that gives
    every process its own cache directory (suffixed with the PID).

    Required to avoid cache-file collisions when running with tp>1 and
    the multiprocessing distributed backend.

    Note this issue was fixed by triton-lang/triton/pull/4295,
    but the fix is not yet included in triton==v3.0.0. However,
    it should be included in the subsequent version.
    """

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None

        if dump:
            # Dump dir: shared location, created eagerly, with a lock file.
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
            return

        if override:
            # Override dir: read-only lookup, never created here.
            self.cache_dir = os.path.join(default_override_dir(), self.key)
            return

        # Default case: honor TRITON_CACHE_DIR if set, otherwise fall back
        # to Triton's default, then isolate per process via the PID suffix.
        base_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
        if not base_dir:
            raise RuntimeError("Could not create or locate cache dir")
        self.cache_dir = os.path.join(f"{base_dir}_{os.getpid()}", self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)

0 commit comments

Comments
 (0)