Commit 9abc9aa
fix: use grad div factor when fsdp_degree=1 (pytorch#167178)
`fully_shard`'s `gradient_divide_factor` is currently not respected when the sharding degree is 1. This PR ensures the division factor is applied in this case as well.
This is a bit of an edge case, but it arises in `torchtitan`: e.g. with expert parallelism and `ep_degree=world_size`, we still wrap the routed experts in `fully_shard` because:
1) It lets us take advantage of its mixed-precision mechanisms.
2) [A specific gradient_divide_factor is needed for correctness](https://github.com/pytorch/torchtitan/blob/176498cd4edd4d80e95959a618279681f8295f4c/torchtitan/models/llama4/infra/parallelize.py?plain=1#L364-L369)
This PR ensures correctness in the `reduce_scatter_group.size()==1` case.
Reproducer and sample failures are in the [gist here](https://gist.github.ibm.com/goon/f67e7559284cc2d322faff1ac59fe382). The net effect is that, in the case described above, the EP grads are too large by a factor of the world size. I checked that the proposed fix makes these tests pass.
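The scaling error described above can be illustrated with a toy model in plain Python. This is only a sketch, not the actual `fully_shard` code path: `reduce_grads` and its `skip_division_for_single_rank` flag are hypothetical names made up for illustration, and the "reduce-scatter" is modelled as a simple elementwise sum over a list of per-rank gradients.

```python
from typing import List, Optional


def reduce_grads(
    per_rank_grads: List[List[float]],
    divide_factor: Optional[float] = None,
    skip_division_for_single_rank: bool = False,
) -> List[float]:
    """Toy stand-in for gradient reduction: sum over the group, then divide.

    Hypothetical helper for illustration only; not a PyTorch API.
    """
    group_size = len(per_rank_grads)
    # Default to averaging over the group when no custom factor is given.
    factor = float(group_size) if divide_factor is None else divide_factor
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    if group_size == 1 and skip_division_for_single_rank:
        # Models the pre-fix behavior: with a single-rank group the
        # division is skipped, so a custom factor is silently ignored.
        return summed
    return [g / factor for g in summed]


world_size = 8            # assumed ep_degree == world_size
local_grad = [8.0, 16.0]  # the reduce-scatter group holds only this rank

old = reduce_grads(
    [local_grad], divide_factor=world_size, skip_division_for_single_rank=True
)
new = reduce_grads([local_grad], divide_factor=world_size)

print(old)  # [8.0, 16.0]
print(new)  # [1.0, 2.0]
```

In this toy setup the pre-fix result is larger than the post-fix result by exactly `world_size`, which matches the factor reported for the EP grads.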
I guess I should add a test for this, too?
Pull Request resolved: pytorch#167178
Approved by: https://github.com/weifengpy
1 parent 789240b commit 9abc9aa
File tree
2 files changed, +43 -16 lines changed
- test/distributed/_composable/fsdp
- torch/distributed/fsdp/_fully_shard
Lines changed: 31 additions & 9 deletions