|
14 | 14 | from dataclasses import dataclass |
15 | 15 | from typing import ( |
16 | 16 | Any, |
| 17 | + cast, |
17 | 18 | ContextManager, |
18 | 19 | Generic, |
19 | 20 | Iterator, |
|
26 | 27 |
|
27 | 28 | import torch |
28 | 29 | from pyre_extensions import none_throws |
29 | | -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| 30 | +from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP |
30 | 31 | from torch.nn.parallel import DistributedDataParallel as DDP |
31 | 32 | from torch.optim.swa_utils import SWALR |
32 | 33 | from torchtnt.framework._unit_utils import _step_requires_iterator |
|
42 | 43 | GradScaler, |
43 | 44 | ) |
44 | 45 | from torchtnt.utils.prepare_module import ( |
| 46 | + _is_fsdp2_module, |
45 | 47 | _is_fsdp_module, |
46 | 48 | ActivationCheckpointParams, |
47 | 49 | FSDPStrategy, |
@@ -672,12 +674,19 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]: |
672 | 674 | # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync |
673 | 675 | # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync |
674 | 676 | maybe_no_sync = ( |
675 | | - # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. |
676 | 677 | module.no_sync() |
677 | 678 | if not should_update_weights |
678 | | - and (isinstance(module, DDP) or _is_fsdp_module(module)) |
| 679 | + and isinstance(module, (DDP, FSDP))
679 | 680 | else contextlib.nullcontext() |
680 | 681 | ) |
| 682 | + # FSDP2 has a separate way of disabling gradient sync
| 683 | + if _is_fsdp2_module(module): |
| 684 | + if not should_update_weights: |
| 685 | + cast(FSDPModule, module).set_requires_gradient_sync(False) |
| 686 | + elif should_update_weights and self.gradient_accumulation_steps > 1: |
| 687 | + # if gradient accumulation is used and it's time to update weights, |
| 688 | + # we need to re-enable gradient sync |
| 689 | + cast(FSDPModule, module).set_requires_gradient_sync(True) |
681 | 690 |
|
682 | 691 | # if detect_anomaly is true, run forward and backward pass in detect_anomaly context |
683 | 692 | detect_anomaly = self.detect_anomaly |
|
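For context on the pattern this hunk implements: DDP and FSDP1 suppress gradient synchronization on accumulation steps via the `no_sync()` context manager, while FSDP2 (`fully_shard`) toggles it through `FSDPModule.set_requires_gradient_sync`. Below is a minimal sketch of that pattern outside TorchTNT; the `backward_with_accumulation` helper and the `step`/`accum_steps` names are illustrative only, and it assumes a PyTorch version where `FSDPModule` is exported from `torch.distributed.fsdp` (as in the imports above).

```python
# Illustrative sketch only -- not TorchTNT's implementation.
import contextlib

import torch
from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP


def backward_with_accumulation(
    module: torch.nn.Module, loss: torch.Tensor, step: int, accum_steps: int
) -> None:
    # Weights are updated (and gradients must be synced) every `accum_steps` steps.
    should_update_weights = (step + 1) % accum_steps == 0

    # DDP and FSDP1: `no_sync()` skips the all-reduce / reduce-scatter for this
    # backward pass; otherwise fall back to a no-op context manager.
    maybe_no_sync = (
        module.no_sync()
        if not should_update_weights and isinstance(module, (DDP, FSDP))
        else contextlib.nullcontext()
    )

    # FSDP2 (fully_shard): gradient sync is a flag on the module rather than a
    # context manager. Simplified here to set it on every step.
    if isinstance(module, FSDPModule):
        module.set_requires_gradient_sync(should_update_weights)

    with maybe_no_sync:
        loss.backward()
```

On non-final micro-batches gradients accumulate locally; on the final one they are reduced across ranks before the optimizer step.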