Skip to content

Commit 3d4425e

Browse files
authored
[graph_trainer] Annotate ac region id for transformer blocks (#2609)
Without per-transformer-block AC region IDs, the min-cut partitioner sees the entire model as a single region. In practice, the partitioner can still rely on existing `MUST_SAVE` nodes as anchors to limit recomputation scope. But recomputation could trace all the way back to the beginning of the model when it doesn't hit a `MUST_SAVE` node. By assigning a unique `ac_graph_id` to each transformer block, the partitioner is forced to `MUST_SAVE` at region boundaries (i.e., between transformer blocks). This ensures recomputation during the backward pass is always contained within a single block. This PR: - Adds `annotate_ac_regions()` to tag each transformer block's forward with a unique `ac_region_id`. - Updates `apply_sac_pass` to read the `ac_region_id` from node custom metadata and set it as the `ac_graph_id`.
1 parent 87920ca commit 3d4425e

File tree

5 files changed

+124
-35
lines changed

5 files changed

+124
-35
lines changed

torchtitan/experiments/graph_trainer/common_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,30 @@
88

99
import torch
1010
import torch.distributed as dist
11+
import torch.nn as nn
1112
from torch.distributed.tensor import DTensor, Replicate
13+
from torch.fx.traceback import annotate_fn
1214
from torch.utils._pytree import register_pytree_node, tree_map
1315

1416
from torchtitan.config import CompileConfig
1517
from torchtitan.distributed import ParallelDims
1618
from torchtitan.tools.logging import logger
1719

20+
_AC_REGION_ID = "ac_region_id"
21+
22+
23+
def annotate_ac_regions(model: nn.Module) -> None:
24+
"""Annotate each transformer block with a unique AC region ID.
25+
26+
This enables apply_sac_pass to assign different ac_graph_id values
27+
per block, creating AC region boundaries between transformer blocks.
28+
"""
29+
layers = model.get_submodule("layers")
30+
for layer_id, transformer_block in layers.named_children():
31+
transformer_block.forward = annotate_fn({_AC_REGION_ID: int(layer_id)})(
32+
transformer_block.forward
33+
)
34+
1835

1936
def parallelize_inputs(parallel_dims, args, kwargs):
2037
if not parallel_dims.tp_enabled:

torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch.nn as nn
87
from torch.distributed.device_mesh import DeviceMesh
98
from torch.fx.traceback import annotate_fn
109

@@ -20,8 +19,14 @@
2019

2120
from torchtitan.distributed.activation_checkpoint import apply_ac
2221
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
23-
from torchtitan.experiments.graph_trainer.common_utils import maybe_disable_eager_ac
22+
from torchtitan.experiments.graph_trainer.common_utils import (
23+
annotate_ac_regions,
24+
maybe_disable_eager_ac,
25+
)
2426
from torchtitan.experiments.graph_trainer.compile import apply_compile
27+
from torchtitan.experiments.graph_trainer.deepseek_v3.model import (
28+
GraphTrainerDeepSeekV3Model,
29+
)
2530
from torchtitan.experiments.graph_trainer.simple_fsdp import (
2631
data_parallel,
2732
MixedPrecisionPolicy,
@@ -31,7 +36,7 @@
3136
from torchtitan.tools.logging import logger
3237

3338

34-
def annotate_deepseekv3() -> None:
39+
def annotate_deepseekv3(model: GraphTrainerDeepSeekV3Model) -> None:
3540
"""Attach annotations to FX graph nodes with ``torch.fx.traceback.annotate_fn``
3641
3742
- Expert Parallel (EP) annotations: Tags "dispatch", "combine", and "compute"
@@ -40,6 +45,9 @@ def annotate_deepseekv3() -> None:
4045
{"compile_with_inductor": "flex_attention"} so the compiler can apply
4146
regional inductor pass based on the annotation. Regional inductor is now only
4247
supported in AOT mode.
48+
- AC region annotation: Tags each transformer block's forward with a unique
49+
ac_region_id so that apply_sac_pass can assign per-block ac_graph_id
50+
boundaries for the min-cut partitioner.
4351
4452
"""
4553
from torchtitan.distributed.expert_parallel import ExpertParallel
@@ -58,10 +66,12 @@ def annotate_deepseekv3() -> None:
5866
{"compile_with_inductor": "flex_attention"}
5967
)(FlexAttentionWrapper.forward)
6068

69+
annotate_ac_regions(model)
70+
6171

6272
# Adapted from llama4/infra/parallelize.py
6373
def parallelize_deepseekv3(
64-
model: nn.Module,
74+
model: GraphTrainerDeepSeekV3Model,
6575
*,
6676
parallel_dims: ParallelDims,
6777
training: TrainingConfig,
@@ -87,7 +97,7 @@ def parallelize_deepseekv3(
8797
):
8898
raise NotImplementedError("CP support is only supported for SDPA.")
8999

90-
annotate_deepseekv3()
100+
annotate_deepseekv3(model)
91101

92102
maybe_disable_eager_ac(compile_config, ac_config)
93103

torchtitan/experiments/graph_trainer/llama3/parallelize.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@
1818
from torchtitan.distributed import ParallelDims
1919
from torchtitan.distributed.activation_checkpoint import apply_ac
2020
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
21-
from torchtitan.experiments.graph_trainer.common_utils import maybe_disable_eager_ac
21+
from torchtitan.experiments.graph_trainer.common_utils import (
22+
annotate_ac_regions,
23+
maybe_disable_eager_ac,
24+
)
2225
from torchtitan.experiments.graph_trainer.compile import apply_compile
26+
from torchtitan.experiments.graph_trainer.llama3.model import GraphTrainerLlama3Model
2327
from torchtitan.experiments.graph_trainer.simple_fsdp import (
2428
data_parallel,
2529
MixedPrecisionPolicy,
2630
)
27-
from torchtitan.models.llama3.model import Llama3Model
2831
from torchtitan.models.llama3.parallelize import apply_tp
2932
from torchtitan.protocols.model_converter import ModelConvertersContainer
3033
from torchtitan.tools.logging import logger
3134

32-
3335
# for selective op activation checkpointing
3436
_op_sac_save_list = {
3537
torch.ops.aten.mm.default,
@@ -50,24 +52,29 @@
5052
}
5153

5254

53-
def annotate_llama() -> None:
55+
def annotate_llama(model: GraphTrainerLlama3Model) -> None:
    """Attach annotations to FX graph nodes with ``torch.fx.traceback.annotate_fn``

    - Flex attention annotation: Tags FlexAttentionWrapper.forward with
      {"compile_with_inductor": "flex_attention"} so the compiler can apply
      regional inductor pass based on the annotation. Regional inductor is now only
      supported in AOT mode.
    - AC region annotation: Tags each transformer block's forward with a unique
      ac_region_id so that apply_sac_pass can assign per-block ac_graph_id
      boundaries for the min-cut partitioner.
    """
    from torchtitan.models.common.attention import FlexAttentionWrapper

    # Mark flex-attention nodes so the regional inductor pass can find them.
    flex_meta = {"compile_with_inductor": "flex_attention"}
    FlexAttentionWrapper.forward = annotate_fn(flex_meta)(
        FlexAttentionWrapper.forward
    )

    annotate_ac_regions(model)
74+
6875

6976
def parallelize_llama(
70-
model: Llama3Model,
77+
model: GraphTrainerLlama3Model,
7178
*,
7279
parallel_dims: ParallelDims,
7380
training: TrainingConfig,
@@ -94,7 +101,7 @@ def parallelize_llama(
94101
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
95102
"""
96103

97-
annotate_llama()
104+
annotate_llama(model)
98105

99106
maybe_disable_eager_ac(compile_config, ac_config)
100107

torchtitan/experiments/graph_trainer/passes.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- Compiler passes: Applied to the partitioned forward/backward graphs
1616
"""
1717
import operator
18+
from collections import defaultdict
1819
from collections.abc import Sequence
1920
from typing import Any
2021

@@ -29,6 +30,7 @@
2930
from torch.fx.passes.regional_inductor import regional_inductor
3031
from torch.utils.checkpoint import CheckpointPolicy
3132

33+
from torchtitan.experiments.graph_trainer.common_utils import _AC_REGION_ID
3234
from torchtitan.experiments.graph_trainer.reshard_after_forward import (
3335
annotate_fsdp_all_gather,
3436
)
@@ -182,6 +184,9 @@ def apply_sac_pass(
182184
op_list_to_save = _get_default_sac_save_ops()
183185

184186
mm_count = 0
187+
ac_region_stats: dict[int, dict[str, int]] = defaultdict(
188+
lambda: {"save": 0, "recompute": 0}
189+
)
185190

186191
for node in gm.graph.nodes:
187192
if node.op != "call_function":
@@ -205,25 +210,37 @@ def apply_sac_pass(
205210
node.meta["ac_graph_id"] = parent.meta.get("ac_graph_id", 0)
206211
continue
207212

208-
node.meta["ac_graph_id"] = 0
213+
custom_meta = node.meta.get("custom", {})
214+
ac_region_id = custom_meta.get(_AC_REGION_ID, 0)
215+
node.meta["ac_graph_id"] = ac_region_id
209216

210217
if node.target is torch.ops.aten.mm.default:
211218
mm_count += 1
212219
# Save every odd mm, recompute every even mm
213220
if mm_count % 2 == 0:
214-
node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
221+
policy = CheckpointPolicy.PREFER_RECOMPUTE
215222
else:
216-
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
223+
policy = CheckpointPolicy.MUST_SAVE
217224
elif node.target in op_list_to_save:
218-
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
225+
policy = CheckpointPolicy.MUST_SAVE
226+
else:
227+
policy = CheckpointPolicy.PREFER_RECOMPUTE
228+
229+
node.meta["recompute"] = policy
230+
if policy == CheckpointPolicy.MUST_SAVE:
231+
ac_region_stats[ac_region_id]["save"] += 1
219232
else:
220-
node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
233+
ac_region_stats[ac_region_id]["recompute"] += 1
221234

222235
gm.recompile()
223-
logger.info(
224-
"Applied selective activation checkpointing (SAC) graph pass "
225-
f"({mm_count} mm ops found, {mm_count - mm_count // 2} saved)"
226-
)
236+
logger.info("Applied selective activation checkpointing (SAC) graph pass.")
237+
for ac_region_id in sorted(ac_region_stats):
238+
stats = ac_region_stats[ac_region_id]
239+
logger.info(
240+
f" AC region {ac_region_id}: "
241+
f"{stats['save']} nodes annotated with MUST_SAVE, "
242+
f"{stats['recompute']} nodes annotated with PREFER_RECOMPUTE"
243+
)
227244
return gm
228245

229246

torchtitan/experiments/graph_trainer/tests/test_passes.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.utils.checkpoint import checkpoint, CheckpointPolicy
1919

2020
from torchtitan.distributed import ParallelDims
21+
from torchtitan.experiments.graph_trainer.common_utils import _AC_REGION_ID
2122
from torchtitan.experiments.graph_trainer.graph_utils import export_joint
2223
from torchtitan.experiments.graph_trainer.passes import (
2324
apply_sac_pass,
@@ -215,11 +216,16 @@ def _build_gm(self, op_targets):
215216
x = graph.placeholder("x")
216217
y = graph.placeholder("y")
217218
last = x
218-
for target in op_targets:
219+
for i, target in enumerate(op_targets):
219220
if target is operator.getitem:
220221
last = graph.call_function(target, args=(last, 0))
221222
else:
222223
last = graph.call_function(target, args=(last, y))
224+
# If the next op is getitem, wrap in a tuple so getitem has
225+
# a proper tuple/list input.
226+
if i + 1 < len(op_targets) and op_targets[i + 1] is operator.getitem:
227+
_make_tuple = lambda x: (x, x)
228+
last = graph.call_function(_make_tuple, args=(last,))
223229
graph.output(last)
224230
return torch.fx.GraphModule(torch.nn.Module(), graph)
225231

@@ -248,41 +254,55 @@ def test_save_ops_marked_must_save(self):
248254
self.assertEqual(len(nodes), 1)
249255
self.assertEqual(nodes[0].meta["recompute"], CheckpointPolicy.MUST_SAVE)
250256

251-
def test_getitem_propagates_parent_tag(self):
252-
"""operator.getitem nodes should inherit the parent's recompute tag."""
257+
def test_getitem_propagates_parent_tags(self):
258+
"""operator.getitem nodes should inherit the parent's recompute tag and ac_graph_id."""
253259
gm = self._build_gm(
254260
[
255261
torch.ops.aten.add.Tensor,
256262
operator.getitem,
257263
torch.ops.aten.relu.default,
258264
]
259265
)
260-
apply_sac_pass(gm)
261266
nodes = self._get_call_function_nodes(gm)
262-
add_node = nodes[0]
263-
getitem_node = nodes[1]
264-
self.assertEqual(add_node.target, torch.ops.aten.add.Tensor)
265-
self.assertEqual(getitem_node.target, operator.getitem)
266-
self.assertEqual(getitem_node.meta["recompute"], add_node.meta["recompute"])
267-
268-
def test_wait_tensor_propagates_parent_tag(self):
269-
"""wait_tensor nodes should inherit the parent's recompute tag."""
267+
# nodes: [add, make_tuple, getitem, relu]
268+
# make_tuple is the tuple-returning parent of getitem
269+
self.assertEqual(nodes[0].target, torch.ops.aten.add.Tensor)
270+
self.assertEqual(nodes[2].target, operator.getitem)
271+
272+
# Set ac_region_id on the tuple-returning parent (the direct parent of getitem)
273+
nodes[1].meta["custom"] = {_AC_REGION_ID: 3}
274+
275+
apply_sac_pass(gm)
276+
277+
tuple_node = nodes[1]
278+
getitem_node = nodes[2]
279+
self.assertEqual(getitem_node.meta["recompute"], tuple_node.meta["recompute"])
280+
self.assertEqual(tuple_node.meta["ac_graph_id"], 3)
281+
self.assertEqual(getitem_node.meta["ac_graph_id"], 3)
282+
283+
def test_wait_tensor_propagates_parent_tags(self):
284+
"""wait_tensor nodes should inherit the parent's recompute tag and ac_graph_id."""
270285
custom_save = {torch.ops._c10d_functional.reduce_scatter_tensor.default}
271286
gm = self._build_gm(
272287
[
273288
torch.ops._c10d_functional.reduce_scatter_tensor.default,
274289
torch.ops._c10d_functional.wait_tensor.default,
275290
]
276291
)
277-
apply_sac_pass(gm, op_list_to_save=custom_save)
278292
nodes = self._get_call_function_nodes(gm)
293+
nodes[0].meta["custom"] = {_AC_REGION_ID: 3}
294+
295+
apply_sac_pass(gm, op_list_to_save=custom_save)
296+
279297
rs_node = nodes[0]
280298
wait_node = nodes[1]
281299
self.assertEqual(rs_node.meta["recompute"], CheckpointPolicy.MUST_SAVE)
282300
self.assertEqual(wait_node.meta["recompute"], CheckpointPolicy.MUST_SAVE)
301+
self.assertEqual(rs_node.meta["ac_graph_id"], 3)
302+
self.assertEqual(wait_node.meta["ac_graph_id"], 3)
283303

284-
def test_ac_graph_id_set(self):
285-
"""All annotated nodes should have ac_graph_id = 0."""
304+
def test_ac_graph_id_defaults_to_zero(self):
305+
"""Nodes without ac_region_id annotation should have ac_graph_id = 0."""
286306
gm = self._build_gm(
287307
[
288308
torch.ops.aten.add.Tensor,
@@ -295,6 +315,24 @@ def test_ac_graph_id_set(self):
295315
if node.target is not operator.getitem:
296316
self.assertEqual(node.meta["ac_graph_id"], 0)
297317

318+
def test_ac_graph_id_from_annotation(self):
319+
"""Nodes with _AC_REGION_ID_KEY in custom metadata should use that as ac_graph_id."""
320+
gm = self._build_gm(
321+
[
322+
torch.ops.aten.add.Tensor,
323+
torch.ops.aten.relu.default,
324+
]
325+
)
326+
nodes = self._get_call_function_nodes(gm)
327+
# Simulate annotate_fn setting custom metadata on different nodes
328+
nodes[0].meta["custom"] = {_AC_REGION_ID: 1}
329+
nodes[1].meta["custom"] = {_AC_REGION_ID: 2}
330+
331+
apply_sac_pass(gm)
332+
333+
self.assertEqual(nodes[0].meta["ac_graph_id"], 1)
334+
self.assertEqual(nodes[1].meta["ac_graph_id"], 2)
335+
298336
def test_custom_op_list_to_save(self):
299337
"""A custom op_list_to_save should override the defaults."""
300338
custom_save = {torch.ops.aten.relu.default}

0 commit comments

Comments
 (0)