Skip to content

Commit 6fad29b

Browse files
Remove graph_pool as member of VllmBackend and argument to CUDAGraphWrapper (#23385)
Signed-off-by: Luka Govedič <[email protected]> Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: ProExpertProg <[email protected]> Co-authored-by: Luka Govedič <[email protected]>
1 parent 6fd45e7 commit 6fad29b

File tree

3 files changed

+7
-20
lines changed

3 files changed

+7
-20
lines changed

vllm/compilation/backends.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
294294

295295
def __init__(self, module: torch.fx.GraphModule,
296296
compile_submod_names: list[str], vllm_config: VllmConfig,
297-
graph_pool, vllm_backend: "VllmBackend"):
297+
vllm_backend: "VllmBackend"):
298298
super().__init__(module)
299299
from torch._guards import detect_fake_mode
300300
self.fake_mode = detect_fake_mode()
301301
self.compile_submod_names = compile_submod_names
302302
self.compilation_config = vllm_config.compilation_config
303-
self.graph_pool = graph_pool
304303
self.vllm_config = vllm_config
305304
self.vllm_backend = vllm_backend
306305
# When True, it annoyingly dumps the torch.fx.Graph on errors.
@@ -359,7 +358,6 @@ def call_module(self, target: torch.fx.node.Target,
359358
runnable=piecewise_backend,
360359
vllm_config=self.vllm_config,
361360
runtime_mode=CUDAGraphMode.PIECEWISE,
362-
graph_pool=self.graph_pool,
363361
cudagraph_options=CUDAGraphOptions(
364362
debug_log_enable=piecewise_backend.is_first_graph,
365363
gc_disable=not piecewise_backend.is_first_graph,
@@ -405,7 +403,6 @@ class VllmBackend:
405403

406404
vllm_config: VllmConfig
407405
compilation_config: CompilationConfig
408-
graph_pool: Any
409406
_called: bool = False
410407
# the graph we compiled
411408
graph: fx.GraphModule
@@ -433,13 +430,6 @@ def __init__(
433430
# them, e.g. backbone (default), eagle_head, etc.
434431
self.prefix = prefix or model_tag
435432

436-
global_graph_pool = current_platform.get_global_graph_pool()
437-
438-
# TODO: in the future, if we want to use multiple
439-
# streams, it might not be safe to share a global pool.
440-
# only investigate this when we use multiple streams
441-
self.graph_pool = global_graph_pool
442-
443433
# Passes to run on the graph post-grad.
444434
self.post_grad_pass_manager = PostGradPassManager()
445435

@@ -586,7 +576,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
586576
# propagate the split graph to the piecewise backend,
587577
# compile submodules with symbolic shapes
588578
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
589-
self.vllm_config, self.graph_pool,
579+
self.vllm_config,
590580
self).run(*example_inputs)
591581

592582
graph_path = os.path.join(local_cache_dir, "computation_graph.py")

vllm/compilation/base_static_graph.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
1313
"""
1414

1515
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
16-
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
16+
runtime_mode: CUDAGraphMode, **kwargs):
1717
"""
1818
Initializes the StaticGraphWrapper class with graph capturing and
1919
execution-related configurations.
@@ -25,9 +25,6 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig,
2525
graph runtime. See CUDAGraphMode in vllm/config.py.
2626
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
2727
are used as concrete runtime mode for cudagraph dispatching.
28-
graph_pool (Any):
29-
Graph memory pool handle, e.g.,
30-
`torch.cuda.graph_pool_handle()`.
3128
Keyword Args:
3229
kwargs: Additional keyword arguments for platform-specific
3330
configurations.

vllm/compilation/cuda_graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,9 @@ def __init__(self,
6767
runnable: Callable,
6868
vllm_config: VllmConfig,
6969
runtime_mode: CUDAGraphMode,
70-
graph_pool: Any = None,
7170
cudagraph_options: Optional[CUDAGraphOptions] = None):
7271
self.runnable = runnable
7372
self.vllm_config = vllm_config
74-
self.graph_pool = graph_pool
7573
self.runtime_mode = runtime_mode
7674
self.compilation_config = vllm_config.compilation_config
7775

@@ -81,8 +79,10 @@ def __init__(self,
8179
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
8280
# need to initialize a CUDAGraphWrapper.
8381
assert self.runtime_mode != CUDAGraphMode.NONE
84-
if self.graph_pool is None:
85-
self.graph_pool = current_platform.get_global_graph_pool()
82+
# TODO: in the future, if we want to use multiple
83+
# streams, it might not be safe to share a global pool.
84+
# only investigate this when we use multiple streams
85+
self.graph_pool = current_platform.get_global_graph_pool()
8686

8787
if cudagraph_options is None:
8888
cudagraph_options = CUDAGraphOptions()

0 commit comments

Comments (0)