Skip to content

Commit 1a7a3d3

Browse files
authored
[graph_trainer] Add CUDAGraph manager for centralized lifecycle management (#2572)
Summary - Introduce `_CUDAGraphManager` to centralize ownership of the shared CUDA graph pool, stream, and all CUDAGraphWrapper instances - Lazily initialize the graph pool/stream on first use (instead of at module import time) - Replace the fragile teardown in `GraphTrainer.close()` with an explicit `cudagraph_teardown()` that destroys all registered wrappers and releases the pool
1 parent ae039c9 commit 1a7a3d3

File tree

2 files changed

+148
-83
lines changed

2 files changed

+148
-83
lines changed

torchtitan/experiments/graph_trainer/cudagraph.py

Lines changed: 145 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
during compilation.
1212
"""
1313

14+
import logging
1415
import warnings
1516
from collections.abc import Callable, Sequence
1617
from typing import Any
@@ -19,125 +20,196 @@
1920
from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager
2021
from torch.utils._ordered_set import OrderedSet
2122

22-
23-
def init_global_graph_pool() -> tuple[
24-
torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream
25-
]:
26-
dummy_graph = torch.cuda.CUDAGraph()
27-
28-
# create a global cudagraph memory pool to allow memory reuse across cudagraphs.
29-
graph_pool = torch.cuda.graph_pool_handle()
30-
31-
# create a global cuda stream for graph capture. we need to use a single stream
32-
# for all allocations to the memory pool, otherwise the allocations to separate streams
33-
# will not be used.
34-
graph_capture_stream = torch.cuda.Stream()
35-
36-
# use a dummy graph to keep the global graph pool alive
37-
with (
38-
# suppress an empty cudagraph warning, since we intentionally create
39-
# an empty cudagraph here
40-
warnings.catch_warnings(record=True),
41-
torch.cuda.graph(
42-
dummy_graph,
43-
pool=graph_pool,
44-
stream=graph_capture_stream,
45-
capture_error_mode="thread_local",
46-
),
47-
):
48-
pass
49-
50-
return dummy_graph, graph_pool, graph_capture_stream
51-
52-
53-
(
54-
_global_dummy_graph,
55-
_global_graph_pool,
56-
_global_graph_capture_stream,
57-
) = init_global_graph_pool()
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class _CUDAGraphManager:
27+
"""A manager to hold a shared graph pool, stream, and wrapper registry."""
28+
29+
def __init__(self) -> None:
30+
self._initialized = False
31+
self._cudagraph_wrappers: list["CUDAGraphWrapper"] = []
32+
self._teardown_called = False
33+
34+
def maybe_initialize(self) -> None:
35+
if self._initialized:
36+
return
37+
38+
self._initialized = True
39+
40+
# create a global cudagraph memory pool to allow memory reuse across cudagraphs.
41+
self.graph_pool = torch.cuda.graph_pool_handle()
42+
43+
# create a global cuda stream for graph capture. we need to use a single stream
44+
# for all allocations to the memory pool, otherwise the allocations to separate
45+
# streams will not be used.
46+
self.stream = torch.cuda.Stream()
47+
48+
# use a dummy graph to keep the global graph pool alive
49+
self._dummy_graph = torch.cuda.CUDAGraph()
50+
with (
51+
# suppress an empty cudagraph warning, since we intentionally create
52+
# an empty cudagraph here
53+
warnings.catch_warnings(record=True),
54+
torch.cuda.graph(
55+
self._dummy_graph,
56+
pool=self.graph_pool,
57+
stream=self.stream,
58+
capture_error_mode="thread_local",
59+
),
60+
):
61+
pass
62+
63+
def register_wrapper(self, wrapper: "CUDAGraphWrapper") -> None:
64+
assert not self._teardown_called, "Cannot register new cudagraph after teardown"
65+
self._cudagraph_wrappers.append(wrapper)
66+
67+
def teardown(self) -> None:
68+
"""Destroy all cudagraphs and release the cudagraph memory pool.
69+
70+
Note [explicit cudagraph teardown]
71+
cudagraph holds reference to nccl which prevents destroy process
72+
group. so we need to explicitly delete cudagraph which is held
73+
in _CUDAGraphManager and CUDAGraphWrapper. If cudagraph is not
74+
used, this is a no-op.
75+
"""
76+
if not self._initialized:
77+
return
78+
if self._teardown_called:
79+
logger.warning("cudagraph manager teardown called twice")
80+
return
81+
82+
for wrapper in self._cudagraph_wrappers:
83+
wrapper.teardown()
84+
self._cudagraph_wrappers.clear()
85+
86+
self._dummy_graph = None
87+
self.stream = None
88+
self.graph_pool = None
89+
self._teardown_called = True
90+
91+
92+
# Process-wide singleton: one shared pool/stream/wrapper registry.
_cg_manager = _CUDAGraphManager()


def cudagraph_teardown() -> None:
    """Tear down every captured cudagraph and release the shared memory pool.

    See Note [explicit cudagraph teardown] for more details.
    """
    _cg_manager.teardown()
58100

59101

60102
class CUDAGraphWrapper:
    """Wraps a callable with cudagraph. It warms up the callable, records cudagraph,
    and replays cudagraph during runtime. It also handles static input tensors, which
    are tensors whose tensor addresses do not change across runs.

    Args:
        runnable: The callable to wrap with CUDA graph. This can be a
            torch.fx.GraphModule when used in an FX graph pass, or any
            callable when used in PyTorch eager mode.
        example_inputs: A list of example inputs to the callable.
        static_input_indices: A tuple of indices identifying static input
            tensors. Static inputs are tensors whose memory addresses remain
            constant across invocations. Common examples include model weights,
            buffers, and outputs from previously wrapped CUDA graph functions.
        should_check_address: Whether to verify static input tensor addresses
            at runtime. This should only be enabled for debugging purposes.
    """

    def __init__(
        self,
        runnable: Callable,
        example_inputs: Sequence[Any],
        static_input_indices: tuple[int] | None = None,
        should_check_address: bool = False,
    ):
        # Ensure the shared pool/stream exist, and register with the manager
        # so cudagraph_teardown() can destroy this wrapper explicitly.
        _cg_manager.maybe_initialize()
        _cg_manager.register_wrapper(self)

        self._runnable = runnable
        self._static_input_indices = OrderedSet(
            static_input_indices if static_input_indices is not None else []
        )
        # Tensor inputs that are NOT static must be copied into the
        # capture-time tensors before every replay.
        self._input_indices_to_copy = [
            i
            for i, inp in enumerate(example_inputs)
            if isinstance(inp, torch.Tensor) and i not in self._static_input_indices
        ]
        # Captured graph; None until the second call (capture happens then).
        self._cudagraph: torch.cuda.CUDAGraph | None = None
        self._has_warmup = False

        # Set at capture time: _args keeps the capture-time inputs alive so
        # their addresses stay valid; _output is owned by the cudagraph pool.
        self._args = None
        self._output = None

        # (debug only) whether check static input tensor addresses during runtime
        self._should_check_address = should_check_address

    def _copy_non_static_inputs(self, *args):
        # In-place copy of the runtime values into the capture-time tensors,
        # which are the addresses baked into the recorded graph.
        for i in self._input_indices_to_copy:
            self._args[i].copy_(args[i])

    def _check_input_types(self, inputs) -> None:
        for inp in inputs:
            assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), (
                "args must be tensor, integer (for dynamic shapes), "
                "or Generator (for random number generator), "
                f"but found {type(inp)}"
            )

    def _check_static_inputs_address(self) -> None:
        # Debug-only: verify that static inputs really kept the addresses
        # recorded at capture time. NOTE: _input_addresses is created lazily
        # in __call__ at capture time.
        for i in self._static_input_indices:
            actual = self._args[i].data_ptr()
            expected = self._input_addresses[i]
            assert expected == actual, (
                "Expected the same static tensor address but found "
                f"{expected} != {actual}"
            )

    def __call__(self, *args):
        # First call: eager warmup only, no capture, returns immediately.
        if not self._has_warmup:
            self._has_warmup = True
            device = torch.cuda.current_device()

            # warmup in cudagraph memory pool to avoid fragmentation
            # across eager memory pool and cudagraph memory pool.
            with _use_cuda_memory_pool_manager(
                device, _cg_manager.graph_pool, _cg_manager.stream
            ):
                out = self._runnable(*args)
            return out

        # Second call: capture the graph, then fall through to replay below
        # (so the capturing call also produces output via replay).
        if self._cudagraph is None:
            self._check_input_types(args)
            self._args = args
            self._input_addresses = [
                x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args
            ]

            self._cudagraph = torch.cuda.CUDAGraph()

            with torch.cuda.graph(
                self._cudagraph,
                pool=_cg_manager.graph_pool,
                stream=_cg_manager.stream,
            ):
                # `output` is managed by pytorch's cudagraph pool
                self._output = self._runnable(*args)

        if self._should_check_address:
            self._check_static_inputs_address()

        # Steady state: refresh non-static inputs, replay, return the
        # graph-owned output tensor(s).
        self._copy_non_static_inputs(*args)
        self._cudagraph.replay()
        return self._output

    def teardown(self) -> None:
        """Destroy cudagraph and release references.
        See Note [explicit cudagraph teardown] for more details.
        """
        self._cudagraph = None
        self._args = None
        self._output = None
141213

142214

143215
def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]:

torchtitan/experiments/graph_trainer/trainer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import gc
87
from dataclasses import dataclass, field
98

109
from torchtitan.experiments.graph_trainer.configs import GraphTrainerCompileConfig
10+
from torchtitan.experiments.graph_trainer.cudagraph import cudagraph_teardown
1111
from torchtitan.trainer import Trainer
1212

1313

@@ -21,12 +21,5 @@ class Config(Trainer.Config):
2121
def close(self) -> None:
    """Close the trainer, then explicitly tear down all cudagraph state.

    cudagraphs hold references to nccl, which would otherwise prevent
    destroying the process group during shutdown.
    """
    super().close()

    # See Note [explicit cudagraph teardown] in cudagraph.py
    cudagraph_teardown()

0 commit comments

Comments
 (0)