@@ -14,6 +14,7 @@
 from dataclasses import dataclass
 from typing import (
     Any,
+    cast,
     ContextManager,
     Generic,
     Iterator,
@@ -26,7 +27,7 @@
 
 import torch
 from pyre_extensions import none_throws
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim.swa_utils import SWALR
 from torchtnt.framework._unit_utils import _step_requires_iterator
@@ -42,6 +43,7 @@
     GradScaler,
 )
 from torchtnt.utils.prepare_module import (
+    _is_fsdp2_module,
     _is_fsdp_module,
     ActivationCheckpointParams,
     FSDPStrategy,
@@ -672,12 +674,19 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
         # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync
         maybe_no_sync = (
-            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
             module.no_sync()
             if not should_update_weights
-            and (isinstance(module, DDP) or _is_fsdp_module(module))
+            and (isinstance(module, DDP) or isinstance(module, FSDP))
             else contextlib.nullcontext()
         )
+        # fsdp2 has a separate way of disabling gradient sync
+        if _is_fsdp2_module(module):
+            if not should_update_weights:
+                cast(FSDPModule, module).set_requires_gradient_sync(False)
+            elif should_update_weights and self.gradient_accumulation_steps > 1:
+                # if gradient accumulation is used and it's time to update weights,
+                # we need to re-enable gradient sync
+                cast(FSDPModule, module).set_requires_gradient_sync(True)
 
         # if detect_anomaly is true, run forward and backward pass in detect_anomaly context
         detect_anomaly = self.detect_anomaly