@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
294
294
295
295
def __init__ (self , module : torch .fx .GraphModule ,
296
296
compile_submod_names : list [str ], vllm_config : VllmConfig ,
297
- graph_pool , vllm_backend : "VllmBackend" ):
297
+ vllm_backend : "VllmBackend" ):
298
298
super ().__init__ (module )
299
299
from torch ._guards import detect_fake_mode
300
300
self .fake_mode = detect_fake_mode ()
301
301
self .compile_submod_names = compile_submod_names
302
302
self .compilation_config = vllm_config .compilation_config
303
- self .graph_pool = graph_pool
304
303
self .vllm_config = vllm_config
305
304
self .vllm_backend = vllm_backend
306
305
# When True, it annoyingly dumps the torch.fx.Graph on errors.
@@ -359,7 +358,6 @@ def call_module(self, target: torch.fx.node.Target,
359
358
runnable = piecewise_backend ,
360
359
vllm_config = self .vllm_config ,
361
360
runtime_mode = CUDAGraphMode .PIECEWISE ,
362
- graph_pool = self .graph_pool ,
363
361
cudagraph_options = CUDAGraphOptions (
364
362
debug_log_enable = piecewise_backend .is_first_graph ,
365
363
gc_disable = not piecewise_backend .is_first_graph ,
@@ -405,7 +403,6 @@ class VllmBackend:
405
403
406
404
vllm_config : VllmConfig
407
405
compilation_config : CompilationConfig
408
- graph_pool : Any
409
406
_called : bool = False
410
407
# the graph we compiled
411
408
graph : fx .GraphModule
@@ -433,13 +430,6 @@ def __init__(
433
430
# them, e.g. backbone (default), eagle_head, etc.
434
431
self .prefix = prefix or model_tag
435
432
436
- global_graph_pool = current_platform .get_global_graph_pool ()
437
-
438
- # TODO: in the future, if we want to use multiple
439
- # streams, it might not be safe to share a global pool.
440
- # only investigate this when we use multiple streams
441
- self .graph_pool = global_graph_pool
442
-
443
433
# Passes to run on the graph post-grad.
444
434
self .post_grad_pass_manager = PostGradPassManager ()
445
435
@@ -586,7 +576,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
586
576
# propagate the split graph to the piecewise backend,
587
577
# compile submodules with symbolic shapes
588
578
PiecewiseCompileInterpreter (self .split_gm , submod_names_to_compile ,
589
- self .vllm_config , self . graph_pool ,
579
+ self .vllm_config ,
590
580
self ).run (* example_inputs )
591
581
592
582
graph_path = os .path .join (local_cache_dir , "computation_graph.py" )
0 commit comments