|
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 | import unittest |
| 8 | +from collections import Counter |
8 | 9 |
|
9 | 10 | import torch |
10 | 11 | import torch.nn as nn |
11 | 12 |
|
12 | 13 | from torchtitan.experiments.graph_trainer.make_fx_tracer import ( |
| 14 | + _copy_fwd_metadata_to_bw_nodes, |
| 15 | + _patch_engine_run_backward, |
13 | 16 | run_traced_module, |
14 | 17 | trace_module, |
15 | 18 | ) |
@@ -255,5 +258,99 @@ def test_dtensor_train_step(self): |
255 | 258 | self.assertTrue(torch.equal(gr.full_tensor(), gt.full_tensor())) |
256 | 259 |
|
257 | 260 |
|
@unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
class TestMetadataPropagation(unittest.TestCase):
    """Tests for _patch_engine_run_backward and _copy_fwd_metadata_to_bw_nodes.

    These tests verify that (a) backward FX nodes produced during tracing carry
    a ``seq_nr`` that links them to their forward counterparts, (b) custom
    forward-node metadata is propagated to the matching backward nodes, and
    (c) the engine-patching context manager cleans up after itself.
    """

    DEVICE = "cuda"
    DTYPE = torch.float32

    def setUp(self):
        # Fixed seed so traced graphs are reproducible across runs.
        torch.manual_seed(42)

    def test_backward_nodes_have_seq_nr(self):
        """Verify that backward FX nodes get seq_nr metadata via the patched engine."""
        model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)
        train_step = TrainStepModule(model, get_loss)
        tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE)
        labels = torch.randint(0, 256, (2, 32), device=self.DEVICE)

        traced_result = trace_module(train_step, (tokens, labels))

        # Collect seq_nr values from all call_function nodes.
        seq_nrs = []
        for node in traced_result.gm.graph.nodes:
            if node.op == "call_function" and "seq_nr" in node.meta:
                seq_nrs.append(node.meta["seq_nr"])

        # There should be seq_nr values present (both fwd and bwd nodes).
        self.assertGreater(len(seq_nrs), 0, "No seq_nr metadata found on any node")

        # There should be duplicate seq_nrs (fwd and bwd nodes sharing seq_nr).
        counts = Counter(seq_nrs)
        shared = [nr for nr, cnt in counts.items() if cnt > 1]
        self.assertGreater(
            len(shared),
            0,
            "Expected some seq_nr values shared between fwd and bwd nodes",
        )

    def test_copy_fwd_metadata_propagates_custom(self):
        """Verify _copy_fwd_metadata_to_bw_nodes copies custom metadata to bwd nodes."""
        model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)

        # Use annotate to set custom metadata on forward nodes, then trace
        # with backward to verify it propagates.
        train_step = TrainStepModule(model, get_loss)
        tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE)
        labels = torch.randint(0, 256, (2, 32), device=self.DEVICE)

        traced_result = trace_module(train_step, (tokens, labels))
        gm = traced_result.gm

        # Manually set custom metadata on the FIRST node seen for each seq_nr
        # (in graph order the forward node precedes its backward counterparts)
        # and make sure later nodes sharing that seq_nr start WITHOUT it.
        # BUG FIX: the assignment was previously unconditional, tagging the
        # backward nodes directly and making the propagation check vacuous.
        seq_nr_first: dict[int, torch.fx.Node] = {}
        for node in gm.graph.nodes:
            if node.op == "call_function" and "seq_nr" in node.meta:
                seq_nr = node.meta["seq_nr"]
                if seq_nr not in seq_nr_first:
                    seq_nr_first[seq_nr] = node
                    node.meta["custom"] = {"test_key": "test_value"}
                else:
                    # Defensive: a passing assertion below must only be
                    # explainable by the copy pass, not pre-existing metadata.
                    node.meta.pop("custom", None)

        # Run the copy pass; it should transfer the fwd metadata to bwd nodes.
        _copy_fwd_metadata_to_bw_nodes(gm)

        # Check that bwd nodes with shared seq_nr got the custom metadata.
        for node in gm.graph.nodes:
            if node.op != "call_function" or "seq_nr" not in node.meta:
                continue
            seq_nr = node.meta["seq_nr"]
            if node is not seq_nr_first.get(seq_nr):
                # This is a non-first node for its seq_nr, i.e. a backward node.
                custom = node.meta.get("custom")
                self.assertIsNotNone(
                    custom,
                    f"Backward node {node.name} with seq_nr={seq_nr} missing custom metadata",
                )
                self.assertEqual(custom.get("test_key"), "test_value")

    def test_patch_engine_restores_original(self):
        """Verify that _patch_engine_run_backward restores the original function."""
        import torch.autograd
        import torch.autograd.graph

        orig_fn = torch.autograd.graph._engine_run_backward

        with _patch_engine_run_backward():
            # Inside the context, both aliases should point at the patched fn.
            self.assertIsNot(torch.autograd.graph._engine_run_backward, orig_fn)
            self.assertIsNot(torch.autograd._engine_run_backward, orig_fn)

        # After the context, it should be restored.
        self.assertIs(torch.autograd.graph._engine_run_backward, orig_fn)
        self.assertIs(torch.autograd._engine_run_backward, orig_fn)
| 354 | + |
# Allow running this test file directly (`python <file>`); unittest discovers
# and runs every TestCase defined above.
if __name__ == "__main__":
    unittest.main()