|
11 | 11 | during compilation. |
12 | 12 | """ |
13 | 13 |
|
| 14 | +import logging |
14 | 15 | import warnings |
15 | 16 | from collections.abc import Callable, Sequence |
16 | 17 | from typing import Any |
|
19 | 20 | from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager |
20 | 21 | from torch.utils._ordered_set import OrderedSet |
21 | 22 |
|
22 | | - |
23 | | -def init_global_graph_pool() -> tuple[ |
24 | | - torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream |
25 | | -]: |
26 | | - dummy_graph = torch.cuda.CUDAGraph() |
27 | | - |
28 | | - # create a global cudagraph memory pool to allow memory reuse across cudagraphs. |
29 | | - graph_pool = torch.cuda.graph_pool_handle() |
30 | | - |
31 | | - # create a global cuda stream for graph capture. we need to use a single stream |
32 | | - # for all allocations to the memory pool, otherwise the allocations to separate streams |
33 | | - # will not be used. |
34 | | - graph_capture_stream = torch.cuda.Stream() |
35 | | - |
36 | | - # use a dummy graph to keep the global graph pool alive |
37 | | - with ( |
38 | | - # suppress an empty cudagraph warning, since we intentionally create |
39 | | - # an empty cudagraph here |
40 | | - warnings.catch_warnings(record=True), |
41 | | - torch.cuda.graph( |
42 | | - dummy_graph, |
43 | | - pool=graph_pool, |
44 | | - stream=graph_capture_stream, |
45 | | - capture_error_mode="thread_local", |
46 | | - ), |
47 | | - ): |
48 | | - pass |
49 | | - |
50 | | - return dummy_graph, graph_pool, graph_capture_stream |
51 | | - |
52 | | - |
53 | | -( |
54 | | - _global_dummy_graph, |
55 | | - _global_graph_pool, |
56 | | - _global_graph_capture_stream, |
57 | | -) = init_global_graph_pool() |
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
| 25 | + |
| 26 | +class _CUDAGraphManager: |
| 27 | + """A manager to hold a shared graph pool, stream, and wrapper registry.""" |
| 28 | + |
| 29 | + def __init__(self) -> None: |
| 30 | + self._initialized = False |
| 31 | + self._cudagraph_wrappers: list["CUDAGraphWrapper"] = [] |
| 32 | + self._teardown_called = False |
| 33 | + |
| 34 | + def maybe_initialize(self) -> None: |
| 35 | + if self._initialized: |
| 36 | + return |
| 37 | + |
| 38 | + self._initialized = True |
| 39 | + |
| 40 | + # create a global cudagraph memory pool to allow memory reuse across cudagraphs. |
| 41 | + self.graph_pool = torch.cuda.graph_pool_handle() |
| 42 | + |
| 43 | + # create a global cuda stream for graph capture. we need to use a single stream
| 44 | + # for all allocations to the memory pool; otherwise, allocations made on separate
| 45 | + # streams cannot be reused.
| 46 | + self.stream = torch.cuda.Stream() |
| 47 | + |
| 48 | + # use a dummy graph to keep the global graph pool alive |
| 49 | + self._dummy_graph = torch.cuda.CUDAGraph() |
| 50 | + with ( |
| 51 | + # suppress an empty cudagraph warning, since we intentionally create |
| 52 | + # an empty cudagraph here |
| 53 | + warnings.catch_warnings(record=True), |
| 54 | + torch.cuda.graph( |
| 55 | + self._dummy_graph, |
| 56 | + pool=self.graph_pool, |
| 57 | + stream=self.stream, |
| 58 | + capture_error_mode="thread_local", |
| 59 | + ), |
| 60 | + ): |
| 61 | + pass |
| 62 | + |
| 63 | + def register_wrapper(self, wrapper: "CUDAGraphWrapper") -> None: |
| 64 | + assert not self._teardown_called, "Cannot register new cudagraph after teardown" |
| 65 | + self._cudagraph_wrappers.append(wrapper) |
| 66 | + |
| 67 | + def teardown(self) -> None: |
| 68 | + """Destroy all cudagraphs and release the cudagraph memory pool. |
| 69 | +
|
| 70 | + Note [explicit cudagraph teardown] |
| 71 | + cudagraph holds a reference to NCCL, which prevents destroying
| 72 | + the process group, so we need to explicitly delete the cudagraphs
| 73 | + held in _CUDAGraphManager and CUDAGraphWrapper. If no cudagraph is
| 74 | + used, this is a no-op.
| 75 | + """ |
| 76 | + if not self._initialized: |
| 77 | + return |
| 78 | + if self._teardown_called: |
| 79 | + logger.warning("cudagraph manager teardown called twice") |
| 80 | + return |
| 81 | + |
| 82 | + for wrapper in self._cudagraph_wrappers: |
| 83 | + wrapper.teardown() |
| 84 | + self._cudagraph_wrappers.clear() |
| 85 | + |
| 86 | + self._dummy_graph = None |
| 87 | + self.stream = None |
| 88 | + self.graph_pool = None |
| 89 | + self._teardown_called = True |
| 90 | + |
| 91 | + |
| 92 | +_cg_manager = _CUDAGraphManager() |
| 93 | + |
| 94 | + |
| 95 | +def cudagraph_teardown() -> None: |
| 96 | + """Destroy all cudagraphs and release the cudagraph memory pool. |
| 97 | + See Note [explicit cudagraph teardown] for more details. |
| 98 | + """ |
| 99 | + _cg_manager.teardown() |
58 | 100 |
|
59 | 101 |
|
60 | 102 | class CUDAGraphWrapper: |
| 103 | + """Wraps a callable with cudagraph. It warms up the callable, records a cudagraph,
| 104 | + and replays the cudagraph at runtime. It also handles static input tensors, i.e.
| 105 | + tensors whose memory addresses do not change across runs.
| 106 | +
|
| 107 | + Args: |
| 108 | + runnable: The callable to wrap with CUDA graph. This can be a |
| 109 | + torch.fx.GraphModule when used in an FX graph pass, or any |
| 110 | + callable when used in PyTorch eager mode. |
| 111 | + example_inputs: A list of example inputs to the callable. |
| 112 | + static_input_indices: A tuple of indices identifying static input |
| 113 | + tensors. Static inputs are tensors whose memory addresses remain |
| 114 | + constant across invocations. Common examples include model weights, |
| 115 | + buffers, and outputs from previously wrapped CUDA graph functions. |
| 116 | + should_check_address: Whether to verify static input tensor addresses |
| 117 | + at runtime. This should only be enabled for debugging purposes. |
| 118 | + """ |
| 119 | + |
61 | 120 | def __init__( |
62 | 121 | self, |
63 | 122 | runnable: Callable, |
64 | 123 | example_inputs: Sequence[Any], |
65 | 124 | static_input_indices: tuple[int] | None = None, |
66 | 125 | should_check_address: bool = False, |
67 | 126 | ): |
68 | | - self.runnable = runnable |
69 | | - self.graph_pool = _global_graph_pool |
70 | | - self.stream = _global_graph_capture_stream |
71 | | - self.static_input_indices = OrderedSet( |
| 127 | + _cg_manager.maybe_initialize() |
| 128 | + _cg_manager.register_wrapper(self) |
| 129 | + |
| 130 | + self._runnable = runnable |
| 131 | + self._static_input_indices = OrderedSet( |
72 | 132 | static_input_indices if static_input_indices is not None else [] |
73 | 133 | ) |
74 | | - self.input_indices_to_copy = [ |
| 134 | + self._input_indices_to_copy = [ |
75 | 135 | i |
76 | 136 | for i, inp in enumerate(example_inputs) |
77 | | - if isinstance(inp, torch.Tensor) and i not in self.static_input_indices |
| 137 | + if isinstance(inp, torch.Tensor) and i not in self._static_input_indices |
78 | 138 | ] |
79 | | - self.cudagraph: torch.cuda.CUDAGraph | None = None |
80 | | - self.has_warmup = False |
| 139 | + self._cudagraph: torch.cuda.CUDAGraph | None = None |
| 140 | + self._has_warmup = False |
81 | 141 |
|
82 | | - self.args = None |
83 | | - self.output = None |
| 142 | + self._args = None |
| 143 | + self._output = None |
84 | 144 |
|
85 | 145 | # (debug only) whether check static input tensor addresses during runtime |
86 | | - self.should_check_address = should_check_address |
| 146 | + self._should_check_address = should_check_address |
87 | 147 |
|
88 | | - def copy_non_static_inputs(self, *args): |
89 | | - for i in self.input_indices_to_copy: |
90 | | - self.args[i].copy_(args[i]) |
| 148 | + def _copy_non_static_inputs(self, *args): |
| 149 | + for i in self._input_indices_to_copy: |
| 150 | + self._args[i].copy_(args[i]) |
91 | 151 |
|
92 | | - def check_input_types(self, inputs) -> None: |
| 152 | + def _check_input_types(self, inputs) -> None: |
93 | 153 | for inp in inputs: |
94 | 154 | assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), ( |
95 | 155 | "args must be tensor, integer (for dynamic shapes), " |
96 | 156 | "or Generator (for random number generator), " |
97 | 157 | f"but found {type(inp)}" |
98 | 158 | ) |
99 | 159 |
|
100 | | - def check_static_inputs_address(self) -> None: |
101 | | - for i in self.static_input_indices: |
102 | | - actual = self.args[i].data_ptr() |
103 | | - expected = self.input_addresses[i] |
| 160 | + def _check_static_inputs_address(self) -> None: |
| 161 | + for i in self._static_input_indices: |
| 162 | + actual = self._args[i].data_ptr() |
| 163 | + expected = self._input_addresses[i] |
104 | 164 | assert expected == actual, ( |
105 | 165 | "Expected the same static tensor address but found " |
106 | 166 | f"{expected} != {actual}" |
107 | 167 | ) |
108 | 168 |
|
109 | 169 | def __call__(self, *args): |
110 | | - if not self.has_warmup: |
111 | | - self.has_warmup = True |
| 170 | + if not self._has_warmup: |
| 171 | + self._has_warmup = True |
112 | 172 | device = torch.cuda.current_device() |
113 | 173 |
|
114 | 174 | # warmup in cudagraph memory pool to avoid fragmentation |
115 | 175 | # across eager memory pool and cudagraph memory pool. |
116 | | - with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): |
117 | | - out = self.runnable(*args) |
| 176 | + with _use_cuda_memory_pool_manager( |
| 177 | + device, _cg_manager.graph_pool, _cg_manager.stream |
| 178 | + ): |
| 179 | + out = self._runnable(*args) |
118 | 180 | return out |
119 | 181 |
|
120 | | - if self.cudagraph is None: |
121 | | - self.check_input_types(args) |
122 | | - self.args = args |
123 | | - self.input_addresses = [ |
| 182 | + if self._cudagraph is None: |
| 183 | + self._check_input_types(args) |
| 184 | + self._args = args |
| 185 | + self._input_addresses = [ |
124 | 186 | x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args |
125 | 187 | ] |
126 | 188 |
|
127 | | - self.cudagraph = torch.cuda.CUDAGraph() |
| 189 | + self._cudagraph = torch.cuda.CUDAGraph() |
128 | 190 |
|
129 | 191 | with torch.cuda.graph( |
130 | | - self.cudagraph, pool=self.graph_pool, stream=self.stream |
| 192 | + self._cudagraph, |
| 193 | + pool=_cg_manager.graph_pool, |
| 194 | + stream=_cg_manager.stream, |
131 | 195 | ): |
132 | 196 | # `output` is managed by pytorch's cudagraph pool |
133 | | - self.output = self.runnable(*args) |
134 | | - |
135 | | - if self.should_check_address: |
136 | | - self.check_static_inputs_address() |
137 | | - |
138 | | - self.copy_non_static_inputs(*args) |
139 | | - self.cudagraph.replay() |
140 | | - return self.output |
| 197 | + self._output = self._runnable(*args) |
| 198 | + |
| 199 | + if self._should_check_address: |
| 200 | + self._check_static_inputs_address() |
| 201 | + |
| 202 | + self._copy_non_static_inputs(*args) |
| 203 | + self._cudagraph.replay() |
| 204 | + return self._output |
| 205 | + |
| 206 | + def teardown(self) -> None: |
| 207 | + """Destroy cudagraph and release references. |
| 208 | + See Note [explicit cudagraph teardown] for more details. |
| 209 | + """ |
| 210 | + self._cudagraph = None |
| 211 | + self._args = None |
| 212 | + self._output = None |
141 | 213 |
|
142 | 214 |
|
143 | 215 | def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: |
|
0 commit comments