1 file changed: +4 −0 lines changed
```diff
@@ -28,6 +28,7 @@
 import torch
 from pyre_extensions import none_throws
 from torch.distributed.fsdp import FSDPModule, FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim.swa_utils import SWALR
 from torchtnt.framework._unit_utils import _step_requires_iterator
@@ -900,6 +901,9 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
                     parameters=module.parameters(),
                     max_norm=clip_grad_norm,
                 )
+            # If sharded, collect the DTensor here
+            if isinstance(total_grad_norm, DTensor):
+                total_grad_norm = total_grad_norm.full_tensor()
 
         # gradient value clipping
         if clip_grad_value:
```
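
For context, here is a minimal, self-contained sketch of the behavior this diff adds. When the module's parameters are sharded as `DTensor`s (e.g. under FSDP2-style sharding), `torch.nn.utils.clip_grad_norm_` returns the total norm as a `DTensor`, and `.full_tensor()` materializes it into a plain replicated `torch.Tensor` for downstream use. The helper name below is hypothetical and not part of the patched file.

```python
import torch
from torch.distributed.tensor import DTensor


def clip_and_collect_grad_norm(
    module: torch.nn.Module, max_norm: float
) -> torch.Tensor:
    # Hypothetical helper illustrating the pattern used in the diff.
    # Clip gradients in place and get the total gradient norm; with sharded
    # (DTensor) parameters the returned norm is itself a DTensor.
    total_grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters=module.parameters(),
        max_norm=max_norm,
    )
    # If sharded, gather the full (replicated) value so callers always
    # receive a plain torch.Tensor.
    if isinstance(total_grad_norm, DTensor):
        total_grad_norm = total_grad_norm.full_tensor()
    return total_grad_norm
```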