
Commit 4059cc4

JKSenthil authored and facebook-github-bot committed
support grad accumulation for fsdp2 (#981)
Summary: Pull Request resolved: #981

Reviewed By: diego-urgell

Differential Revision: D70720659

fbshipit-source-id: 84f2af7c7c06c9f730518ebdb17aed3d8cc31a55
1 parent 99c5cda commit 4059cc4

2 files changed: +42, -3 lines changed

tests/framework/test_auto_unit.py

Lines changed: 30 additions & 0 deletions
@@ -750,6 +750,36 @@ def test_detect_anomaly_disabled_with_torch_compile(self) -> None:

         self.assertIsNone(auto_unit.detect_anomaly)

+    @patch("torchtnt.framework.auto_unit._is_fsdp2_module", return_value=True)
+    def test_gradient_accumulation_fsdp2(self, _) -> None:
+        auto_unit = DummyAutoUnit(
+            module=torch.nn.Linear(1, 1),
+            gradient_accumulation_steps=3,
+        )
+
+        # Dynamically add a mocked method as an attribute to module
+        fsdp_module_mock = MagicMock()
+        auto_unit.module.set_requires_gradient_sync = fsdp_module_mock
+        auto_unit._is_last_batch = False
+
+        state = get_dummy_train_state()
+
+        # Simulate train steps
+        for step in range(4):
+            # Call train_step to trigger set_requires_gradient_sync
+            auto_unit.train_step(state, (torch.rand(1, 1), torch.rand(1, 1)))
+
+            # Check if set_requires_gradient_sync is called with the correct boolean
+            if (step + 1) % auto_unit.gradient_accumulation_steps == 0:
+                auto_unit.module.set_requires_gradient_sync.assert_called_with(True)
+            else:
+                auto_unit.module.set_requires_gradient_sync.assert_called_with(False)
+
+            # Reset mock for the next iteration
+            auto_unit.module.set_requires_gradient_sync.reset_mock()
+
+            auto_unit.train_progress.increment_step()
+

 Batch = Tuple[torch.Tensor, torch.Tensor]

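For reference, the boolean asserted above mirrors the weight-update condition used for gradient accumulation: gradient sync is only required on the step that completes an accumulation window, or on the last batch. A minimal standalone sketch of that decision, assuming a zero-based step counter (this is a paraphrase for illustration, not the exact AutoUnit implementation):

    def _requires_gradient_sync(
        steps_completed: int, gradient_accumulation_steps: int, is_last_batch: bool
    ) -> bool:
        # Gradients are reduced across ranks only when the accumulation window
        # completes, or when the epoch ends mid-window.
        return (steps_completed + 1) % gradient_accumulation_steps == 0 or is_last_batch

    # With gradient_accumulation_steps=3, as in the test above:
    # steps 0 and 1 -> False (sync skipped), step 2 -> True (sync + weight update)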
torchtnt/framework/auto_unit.py

Lines changed: 12 additions & 3 deletions
@@ -14,6 +14,7 @@
 from dataclasses import dataclass
 from typing import (
     Any,
+    cast,
     ContextManager,
     Generic,
     Iterator,

@@ -26,7 +27,7 @@

 import torch
 from pyre_extensions import none_throws
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim.swa_utils import SWALR
 from torchtnt.framework._unit_utils import _step_requires_iterator

@@ -42,6 +43,7 @@
     GradScaler,
 )
 from torchtnt.utils.prepare_module import (
+    _is_fsdp2_module,
     _is_fsdp_module,
     ActivationCheckpointParams,
     FSDPStrategy,

@@ -672,12 +674,19 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
         # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
         maybe_no_sync = (
-            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
             module.no_sync()
             if not should_update_weights
-            and (isinstance(module, DDP) or _is_fsdp_module(module))
+            and (isinstance(module, DDP) or isinstance(module, FSDP))
             else contextlib.nullcontext()
         )
+        # fsdp2 has separate way of disabling gradient sync
+        if _is_fsdp2_module(module):
+            if not should_update_weights:
+                cast(FSDPModule, module).set_requires_gradient_sync(False)
+            elif should_update_weights and self.gradient_accumulation_steps > 1:
+                # if gradient accumulation is used and it's time to update weights,
+                # we need to re-enable gradient sync
+                cast(FSDPModule, module).set_requires_gradient_sync(True)

         # if detect_anomaly is true, run forward and backward pass in detect_anomaly context
         detect_anomaly = self.detect_anomaly
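Outside of AutoUnit, the same FSDP2 control flow can be written by hand. The sketch below is illustrative only: it assumes the process group is already initialized, the model has already been sharded with FSDP2 (fully_shard) so that it is an FSDPModule, and train_epoch, loss_fn, and grad_accum_steps are placeholder names. set_requires_gradient_sync plays the role that the no_sync() context manager plays for DDP and FSDP1 in the hunk above.

    from torch.distributed.fsdp import FSDPModule

    def train_epoch(model, optimizer, dataloader, loss_fn, grad_accum_steps: int = 3) -> None:
        # `model` is assumed to already be wrapped with FSDP2 (fully_shard),
        # so it exposes set_requires_gradient_sync.
        assert isinstance(model, FSDPModule)
        for step, (inputs, targets) in enumerate(dataloader):
            sync_step = (step + 1) % grad_accum_steps == 0
            # Skip gradient reduce-scatter on accumulation steps; re-enable it
            # on the step that will update the weights.
            model.set_requires_gradient_sync(sync_step)
            # Scale the loss so accumulated gradients average over the window.
            loss = loss_fn(model(inputs), targets) / grad_accum_steps
            loss.backward()
            if sync_step:
                optimizer.step()
                optimizer.zero_grad()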
