
Commit d501c48

RushabhM authored and facebook-github-bot committed
Collecting DTensor in case of sharding (#1021)
Summary: Pull Request resolved: #1021
Reviewed By: JKSenthil
Differential Revision: D79264984
fbshipit-source-id: 9c18a394e3662b0d62269c6c93b890058fe251db
1 parent 3e43021 commit d501c48

File tree

1 file changed: +4 -0 lines changed


torchtnt/framework/auto_unit.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,7 @@
 import torch
 from pyre_extensions import none_throws
 from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim.swa_utils import SWALR
 from torchtnt.framework._unit_utils import _step_requires_iterator
@@ -900,6 +901,9 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
                 parameters=module.parameters(),
                 max_norm=clip_grad_norm,
             )
+            # If sharded, collect the DTensor here
+            if isinstance(total_grad_norm, DTensor):
+                total_grad_norm = total_grad_norm.full_tensor()

         # gradient value clipping
         if clip_grad_value:
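
For context, a minimal standalone sketch of the logic this commit adds (not part of the commit itself; the helper name _clip_and_collect_grad_norm is hypothetical): when a module's parameters are sharded (for example with FSDP2's fully_shard), torch.nn.utils.clip_grad_norm_ returns the total gradient norm as a DTensor, and calling .full_tensor() collects it into a plain torch.Tensor so downstream logging and comparisons behave the same as in the non-sharded case.

# Illustrative sketch only; mirrors the added check in _update_weights.
import torch
from torch.distributed.tensor import DTensor


def _clip_and_collect_grad_norm(module: torch.nn.Module, max_norm: float) -> torch.Tensor:
    # Clip gradients in place and obtain the total gradient norm.
    total_grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=module.parameters(),
        max_norm=max_norm,
    )
    # With sharded parameters the norm comes back as a DTensor; collect it
    # into a regular tensor so callers can log or compare it directly.
    if isinstance(total_grad_norm, DTensor):
        total_grad_norm = total_grad_norm.full_tensor()
    return total_grad_norm

For an unsharded module the isinstance check is simply a no-op, so the same code path works in both cases.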
