|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +import contextlib |
7 | 8 | import unittest |
8 | 9 | from collections import Counter |
9 | 10 |
|
10 | 11 | import torch |
11 | 12 | import torch.nn as nn |
| 13 | +from torch.testing._internal.common_fsdp import FSDPTest |
12 | 14 |
|
13 | 15 | from torchtitan.experiments.graph_trainer.make_fx_tracer import ( |
14 | 16 | _copy_fwd_metadata_to_bw_nodes, |
@@ -38,7 +40,7 @@ def forward(self, *args): |
38 | 40 | loss = self.loss_fn(logits, labels) |
39 | 41 | # Must look up params in forward (not __init__) so that |
40 | 42 | # _reparametrize_module's swapped parameters are captured during tracing. |
41 | | - params = [p for _, p in self.model.named_parameters(remove_duplicate=False)] |
| 43 | + params = list(self.model.parameters()) |
42 | 44 | grads = torch.autograd.grad(loss, params) |
43 | 45 | return [loss] + list(grads) |
44 | 46 |
|
@@ -352,5 +354,347 @@ def test_patch_engine_restores_original(self): |
352 | 354 | self.assertIs(torch.autograd._engine_run_backward, orig_fn) |
353 | 355 |
|
354 | 356 |
|
@contextlib.contextmanager
def _use_raw_flex_attn():
    """Temporarily make FlexAttentionWrapper dispatch to eager flex_attention.

    The wrapper normally holds a torch.compile'd flex_attention, and
    symbolically tracing a dynamo-optimized callable with make_fx fails with:
    "Detected that you are using FX to symbolically trace a
    dynamo-optimized function."
    With the raw (uncompiled) version, plain make_fx decomposes
    flex_attention into ordinary aten ops (bmm, softmax, etc.) and traces
    cleanly.

    Note: under make_fx(..., pre_dispatch=True) the raw flex_attention is
    kept in the graph as a FlexAttentionHOP higher-order op rather than
    decomposed — the same behavior torch.export exhibits.
    """
    from torch.nn.attention.flex_attention import flex_attention

    from torchtitan.models.common.attention import FlexAttentionWrapper

    saved_attn = FlexAttentionWrapper._compiled_flex_attn
    FlexAttentionWrapper._compiled_flex_attn = staticmethod(flex_attention)
    try:
        yield
    finally:
        # Always restore the compiled version, even if the caller raised.
        FlexAttentionWrapper._compiled_flex_attn = saved_attn
| 383 | + |
@unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
class TestTraceModels(unittest.TestCase):
    """Bitwise-equivalence tests between eager training and a traced train step.

    For each model config, a reference model is trained eagerly for a few
    steps while an identical copy is trained by replaying a make_fx-traced
    train-step graph. Losses and gradients must match bitwise at every step.
    """

    DEVICE = "cuda"
    DTYPE = torch.float32
    BATCH_SIZE = 2
    SEQ_LEN = 128
    NUM_STEPS = 5
    LR = 1e-3

    def setUp(self):
        # Fixed seed + deterministic kernels so the eager and traced runs
        # are bitwise comparable.
        torch.manual_seed(42)
        torch.use_deterministic_algorithms(True)

    def tearDown(self):
        # Don't leak determinism mode into other test classes.
        torch.use_deterministic_algorithms(False)

    def _run_bitwise_test(
        self,
        model_ref,
        model_copy,
        fwd_args,
        labels,
        check_collective_ops=False,
        num_steps=5,
        lr=1e-3,
    ):
        """Train `model_ref` eagerly and `model_copy` via the traced graph.

        Args:
            model_ref: model trained with plain autograd; also the module
                that gets traced.
            model_copy: independent copy (same weights) trained by replaying
                the traced graph with its own parameters/buffers.
            fwd_args: positional forward inputs (tokens, optional attn masks).
            labels: target token ids for the loss.
            check_collective_ops: assert the traced graph contains FSDP
                all_gather / reduce_scatter nodes.
            num_steps: number of optimizer steps to compare.
            lr: Adam learning rate for both optimizers.

        Asserts bitwise equality of loss and every gradient at each step.
        """
        train_step_ref = TrainStepModule(model_ref, get_loss)

        # Trace once up front; the same graph is replayed every step with
        # fresh parameters (mirrors how a graph trainer would run).
        with _use_raw_flex_attn():
            traced_result = trace_module(train_step_ref, (*fwd_args, labels))

        if check_collective_ops:
            ag = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "all_gather_into_tensor" in str(n.target)
            )
            rs = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "reduce_scatter_tensor" in str(n.target)
            )
            self.assertTrue(
                ag > 0 and rs > 0,
                f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})",
            )

        opt_ref = torch.optim.Adam(model_ref.parameters(), lr=lr)
        opt_copy = torch.optim.Adam(model_copy.parameters(), lr=lr)

        for step in range(1, num_steps + 1):
            # Eager reference step. Raw flex attention here too, so both
            # sides execute identical kernels (bitwise comparison).
            with _use_raw_flex_attn():
                logits_ref = model_ref(*fwd_args)
                loss_ref = get_loss(logits_ref, labels)
            loss_ref.backward()
            grads_ref = [p.grad.clone() for p in model_ref.parameters()]
            opt_ref.step()
            opt_ref.zero_grad()

            # Traced step: run the graph with model_copy's current
            # parameters/buffers; it returns [loss, *grads].
            train_step_copy = TrainStepModule(model_copy, get_loss)
            pab = _get_params_and_buffers(train_step_copy)
            wrapped = run_traced_module(traced_result, pab, (*fwd_args, labels))
            loss_tr = wrapped[0]
            grads_tr = wrapped[1:]
            for p, g in zip(model_copy.parameters(), grads_tr, strict=True):
                p.grad = g
            opt_copy.step()
            opt_copy.zero_grad()

            self.assertTrue(
                torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
            )
            for gr, gt in zip(grads_ref, grads_tr, strict=True):
                self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")

    def _run_model_test(self, config_cls, model_config, use_attn_masks=False):
        """Build ref/copy models plus random data and run the bitwise test.

        When `use_attn_masks` is set, a causal flex-attention mask is built
        and passed as a second forward argument (needed by e.g. Llama4).
        """
        vocab_size = model_config.vocab_size
        model_ref = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        model_copy = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )

        # Build forward args once instead of duplicating the whole
        # _run_bitwise_test call per branch.
        fwd_args = (tokens,)
        if use_attn_masks:
            from torchtitan.models.common.attention import (
                create_attention_mask,
                get_causal_mask_mod,
            )

            attn_masks = create_attention_mask(
                get_causal_mask_mod(), 1, None, self.SEQ_LEN, self.SEQ_LEN
            )
            fwd_args = (tokens, attn_masks)

        self._run_bitwise_test(
            model_ref,
            model_copy,
            fwd_args,
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )

    def test_llama3(self):
        from torchtitan.models.llama3 import llama3_configs, Llama3Model

        self._run_model_test(Llama3Model, llama3_configs["debugmodel"])

    def test_qwen3(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel"])

    def test_qwen3_moe(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel_moe"])

    def test_deepseek_v3(self):
        from torchtitan.models.deepseek_v3 import deepseekv3_configs
        from torchtitan.models.deepseek_v3.model import DeepSeekV3Model

        self._run_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"])

    def test_llama4(self):
        from torchtitan.models.llama4 import llama4_configs
        from torchtitan.models.llama4.model import Llama4Model

        self._run_model_test(
            Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True
        )

    def test_gpt_oss(self):
        """GPT-OSS needs two masks (causal + sliding-window), so it builds
        its inputs by hand instead of going through _run_model_test."""
        from torch.nn.attention.flex_attention import and_masks

        from torchtitan.models.common.attention import (
            create_attention_mask,
            get_causal_mask_mod,
            get_sliding_window_mask_mod,
        )
        from torchtitan.models.gpt_oss import gptoss_configs
        from torchtitan.models.gpt_oss.model import GptOssModel

        config = gptoss_configs["debugmodel"]
        vocab_size = config.vocab_size
        model_ref = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        causal = get_causal_mask_mod()
        sw_size = config.layer.attention.sliding_window_size
        basic_mask = create_attention_mask(causal, 1, None, self.SEQ_LEN, self.SEQ_LEN)
        sliding_window_mask = create_attention_mask(
            and_masks(causal, get_sliding_window_mask_mod(sw_size)),
            1,
            None,
            self.SEQ_LEN,
            self.SEQ_LEN,
        )
        attn_masks = {
            "basic_mask": basic_mask,
            "sliding_window_mask": sliding_window_mask,
        }
        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens, attn_masks),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )
| 574 | + |
| 575 | + |
class TestTraceFSDP(FSDPTest):
    """Multi-GPU variant of the bitwise trace tests, with simple-FSDP sharding.

    Runs under torch.testing._internal.common_fsdp.FSDPTest, which spawns
    one process per GPU. Each test shards a reference model and a copy,
    traces the train step (collectives included in the graph), then checks
    the traced replay matches eager training bitwise.
    """

    @property
    def world_size(self):
        # Cap at 4 ranks regardless of how many GPUs the host exposes.
        return min(torch.cuda.device_count(), 4)

    def _setup(self):
        from torchtitan.distributed import ParallelDims

        # dp_shard=-1 means "use all remaining ranks for FSDP sharding";
        # every other parallelism dimension is disabled (size 1).
        self.parallel_dims = ParallelDims(
            dp_shard=-1,
            dp_replicate=1,
            cp=1,
            tp=1,
            pp=1,
            ep=1,
            etp=1,
            world_size=self.world_size,
        )

    def _run_fsdp_model_test(self, config_cls, model_config, use_attn_masks=False):
        """Shard ref/copy models with simple FSDP and run the bitwise check.

        NOTE(review): this mirrors TestTraceModels._run_bitwise_test with
        check_collective_ops=True, but duplicated here because this class
        derives from FSDPTest rather than that TestCase.
        """
        from torchtitan.experiments.graph_trainer.simple_fsdp import data_parallel

        self._setup()
        fsdp_mesh = self.parallel_dims.get_mesh("fsdp")

        # Copy weights BEFORE sharding so both models start identical,
        # then shard both the same way.
        model_ref = create_model(config_cls, model_config, "cuda", torch.float32)
        model_copy = create_model(config_cls, model_config, "cuda", torch.float32)
        model_copy.load_state_dict(model_ref.state_dict())
        data_parallel(model_ref, device_mesh=fsdp_mesh, mode="fully_shard")
        data_parallel(model_copy, device_mesh=fsdp_mesh, mode="fully_shard")

        vocab_size = model_config.vocab_size
        seq_len = 128
        tokens = torch.randint(0, vocab_size, (2, seq_len), device="cuda")
        labels = torch.randint(0, vocab_size, (2, seq_len), device="cuda")

        if use_attn_masks:
            from torchtitan.models.common.attention import (
                create_attention_mask,
                get_causal_mask_mod,
            )

            attn_masks = create_attention_mask(
                get_causal_mask_mod(), 1, None, seq_len, seq_len
            )
            fwd_args = (tokens, attn_masks)
        else:
            fwd_args = (tokens,)

        train_step_ref = TrainStepModule(model_ref, get_loss)

        # Trace once; raw (uncompiled) flex attention so make_fx can trace.
        with _use_raw_flex_attn():
            traced_result = trace_module(train_step_ref, (*fwd_args, labels))

        # The sharded model's collectives must appear in the traced graph:
        # all_gather for parameter unsharding, reduce_scatter for gradients.
        ag = sum(
            1
            for n in traced_result.gm.graph.nodes
            if "all_gather_into_tensor" in str(n.target)
        )
        rs = sum(
            1
            for n in traced_result.gm.graph.nodes
            if "reduce_scatter_tensor" in str(n.target)
        )
        self.assertTrue(
            ag > 0 and rs > 0,
            f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})",
        )

        opt_ref = torch.optim.Adam(model_ref.parameters(), lr=1e-3)
        opt_copy = torch.optim.Adam(model_copy.parameters(), lr=1e-3)

        # 5 optimizer steps; each step compares loss and all grads bitwise.
        for step in range(1, 6):
            # Eager reference step (raw flex attention for bitwise parity).
            with _use_raw_flex_attn():
                logits_ref = model_ref(*fwd_args)
                loss_ref = get_loss(logits_ref, labels)
            loss_ref.backward()
            grads_ref = [p.grad.clone() for p in model_ref.parameters()]
            opt_ref.step()
            opt_ref.zero_grad()

            # Traced step: replay the graph with model_copy's current
            # params/buffers; the graph returns [loss, *grads].
            train_step_copy = TrainStepModule(model_copy, get_loss)
            pab = _get_params_and_buffers(train_step_copy)
            wrapped = run_traced_module(traced_result, pab, (*fwd_args, labels))
            loss_tr = wrapped[0]
            grads_tr = wrapped[1:]
            for p, g in zip(model_copy.parameters(), grads_tr, strict=True):
                p.grad = g
            opt_copy.step()
            opt_copy.zero_grad()

            self.assertTrue(
                torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
            )
            for gr, gt in zip(grads_ref, grads_tr, strict=True):
                self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")

    def test_llama3_fsdp(self):
        from torchtitan.models.llama3 import llama3_configs, Llama3Model

        self._run_fsdp_model_test(Llama3Model, llama3_configs["debugmodel"])

    def test_qwen3_fsdp(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_fsdp_model_test(Qwen3Model, qwen3_configs["debugmodel"])

    def test_deepseek_v3_fsdp(self):
        from torchtitan.models.deepseek_v3 import deepseekv3_configs
        from torchtitan.models.deepseek_v3.model import DeepSeekV3Model

        self._run_fsdp_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"])

    def test_llama4_fsdp(self):
        from torchtitan.models.llama4 import llama4_configs
        from torchtitan.models.llama4.model import Llama4Model

        self._run_fsdp_model_test(
            Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True
        )
| 697 | + |
| 698 | + |
# Allow running this test module directly: `python <module>.py`.
if __name__ == "__main__":
    unittest.main()
0 commit comments