Commit 3d4425e
authored
[graph_trainer] Annotate ac region id for transformer blocks (#2609)
Without per-transformer-block AC region IDs, the min-cut partitioner
sees the entire model as a single region. In practice, the partitioner
can still rely on existing `MUST_SAVE` nodes as anchors to limit
recomputation scope. But recomputation could trace all the way back to
the beginning of the model when it doesn't hit `MUST_SAVE node.
By assigning a unique `ac_graph_id` to each transformer block, the
partitioner is forced to `MUST_SAVE` at region boundaries (i.e., between
transformer blocks). This ensures recomputation during the backward pass
is always contained within a single block.
This PR:
- Adds `annotate_ac_regions()` to tag each transformer block's forward
with a unique `ac_region_id`.
- Updates `apply_sac_pass` to read the `ac_region_id` from node custom
metadata and set it as the `ac_graph_id`.1 parent 87920ca commit 3d4425e
File tree
5 files changed
+124
-35
lines changed- torchtitan/experiments/graph_trainer
- deepseek_v3
- llama3
- tests
5 files changed
+124
-35
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
| 13 | + | |
12 | 14 | | |
13 | 15 | | |
14 | 16 | | |
15 | 17 | | |
16 | 18 | | |
17 | 19 | | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
18 | 35 | | |
19 | 36 | | |
20 | 37 | | |
| |||
Lines changed: 15 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
7 | | - | |
8 | 7 | | |
9 | 8 | | |
10 | 9 | | |
| |||
20 | 19 | | |
21 | 20 | | |
22 | 21 | | |
23 | | - | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
24 | 26 | | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
25 | 30 | | |
26 | 31 | | |
27 | 32 | | |
| |||
31 | 36 | | |
32 | 37 | | |
33 | 38 | | |
34 | | - | |
| 39 | + | |
35 | 40 | | |
36 | 41 | | |
37 | 42 | | |
| |||
40 | 45 | | |
41 | 46 | | |
42 | 47 | | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
43 | 51 | | |
44 | 52 | | |
45 | 53 | | |
| |||
58 | 66 | | |
59 | 67 | | |
60 | 68 | | |
| 69 | + | |
| 70 | + | |
61 | 71 | | |
62 | 72 | | |
63 | 73 | | |
64 | | - | |
| 74 | + | |
65 | 75 | | |
66 | 76 | | |
67 | 77 | | |
| |||
87 | 97 | | |
88 | 98 | | |
89 | 99 | | |
90 | | - | |
| 100 | + | |
91 | 101 | | |
92 | 102 | | |
93 | 103 | | |
| |||
Lines changed: 13 additions & 6 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
21 | | - | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
22 | 25 | | |
| 26 | + | |
23 | 27 | | |
24 | 28 | | |
25 | 29 | | |
26 | 30 | | |
27 | | - | |
28 | 31 | | |
29 | 32 | | |
30 | 33 | | |
31 | 34 | | |
32 | | - | |
33 | 35 | | |
34 | 36 | | |
35 | 37 | | |
| |||
50 | 52 | | |
51 | 53 | | |
52 | 54 | | |
53 | | - | |
| 55 | + | |
54 | 56 | | |
55 | 57 | | |
56 | 58 | | |
57 | 59 | | |
58 | 60 | | |
59 | 61 | | |
60 | 62 | | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
61 | 66 | | |
62 | 67 | | |
63 | 68 | | |
64 | 69 | | |
65 | 70 | | |
66 | 71 | | |
67 | 72 | | |
| 73 | + | |
| 74 | + | |
68 | 75 | | |
69 | 76 | | |
70 | | - | |
| 77 | + | |
71 | 78 | | |
72 | 79 | | |
73 | 80 | | |
| |||
94 | 101 | | |
95 | 102 | | |
96 | 103 | | |
97 | | - | |
| 104 | + | |
98 | 105 | | |
99 | 106 | | |
100 | 107 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
15 | 15 | | |
16 | 16 | | |
17 | 17 | | |
| 18 | + | |
18 | 19 | | |
19 | 20 | | |
20 | 21 | | |
| |||
29 | 30 | | |
30 | 31 | | |
31 | 32 | | |
| 33 | + | |
32 | 34 | | |
33 | 35 | | |
34 | 36 | | |
| |||
182 | 184 | | |
183 | 185 | | |
184 | 186 | | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
185 | 190 | | |
186 | 191 | | |
187 | 192 | | |
| |||
205 | 210 | | |
206 | 211 | | |
207 | 212 | | |
208 | | - | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
209 | 216 | | |
210 | 217 | | |
211 | 218 | | |
212 | 219 | | |
213 | 220 | | |
214 | | - | |
| 221 | + | |
215 | 222 | | |
216 | | - | |
| 223 | + | |
217 | 224 | | |
218 | | - | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
219 | 232 | | |
220 | | - | |
| 233 | + | |
221 | 234 | | |
222 | 235 | | |
223 | | - | |
224 | | - | |
225 | | - | |
226 | | - | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
227 | 244 | | |
228 | 245 | | |
229 | 246 | | |
| |||
Lines changed: 53 additions & 15 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| 21 | + | |
21 | 22 | | |
22 | 23 | | |
23 | 24 | | |
| |||
215 | 216 | | |
216 | 217 | | |
217 | 218 | | |
218 | | - | |
| 219 | + | |
219 | 220 | | |
220 | 221 | | |
221 | 222 | | |
222 | 223 | | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
223 | 229 | | |
224 | 230 | | |
225 | 231 | | |
| |||
248 | 254 | | |
249 | 255 | | |
250 | 256 | | |
251 | | - | |
252 | | - | |
| 257 | + | |
| 258 | + | |
253 | 259 | | |
254 | 260 | | |
255 | 261 | | |
256 | 262 | | |
257 | 263 | | |
258 | 264 | | |
259 | 265 | | |
260 | | - | |
261 | 266 | | |
262 | | - | |
263 | | - | |
264 | | - | |
265 | | - | |
266 | | - | |
267 | | - | |
268 | | - | |
269 | | - | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
270 | 285 | | |
271 | 286 | | |
272 | 287 | | |
273 | 288 | | |
274 | 289 | | |
275 | 290 | | |
276 | 291 | | |
277 | | - | |
278 | 292 | | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
279 | 297 | | |
280 | 298 | | |
281 | 299 | | |
282 | 300 | | |
| 301 | + | |
| 302 | + | |
283 | 303 | | |
284 | | - | |
285 | | - | |
| 304 | + | |
| 305 | + | |
286 | 306 | | |
287 | 307 | | |
288 | 308 | | |
| |||
295 | 315 | | |
296 | 316 | | |
297 | 317 | | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
298 | 336 | | |
299 | 337 | | |
300 | 338 | | |
| |||
0 commit comments