
Commit be1ed63

JKSenthil authored and facebook-github-bot committed
update clip grad norm for fsdp2
Reviewed By: vdogaru
Differential Revision: D70268715
fbshipit-source-id: 5f252597a1ff11f61ed692850b8946bfd3758906
1 parent: 722b387

File tree: 1 file changed (+6 −12 lines)

torchtnt/framework/auto_unit.py

Lines changed: 6 additions & 12 deletions
@@ -879,19 +879,13 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         total_grad_norm = None
         # gradient norm clipping
         if clip_grad_norm:
-            if _is_fsdp_module(module):
-                if isinstance(module, FSDP):
-                    with get_timing_context(
-                        state, f"{self.__class__.__name__}.clip_grad_norm"
-                    ):
-                        total_grad_norm = module.clip_grad_norm_(
-                            max_norm=clip_grad_norm
-                        )
-                else:
-                    raise RuntimeError(
-                        "Composable FSDP clip_grad_norm is not yet implemented: https://github.com/pytorch/pytorch/issues/97271"
-                    )
+            if isinstance(module, FSDP):
+                with get_timing_context(
+                    state, f"{self.__class__.__name__}.clip_grad_norm"
+                ):
+                    total_grad_norm = module.clip_grad_norm_(max_norm=clip_grad_norm)
             else:
+                # strategy=None, DDP, and FSDP2 will work with this
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.clip_grad_norm"
                 ):
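
For context, here is a minimal sketch (not part of the commit; the clip_gradients helper and the FSDP import alias below are assumptions) of the two clipping paths the hunk distinguishes: an FSDP1-wrapped module must use the wrapper's own sharding-aware clip_grad_norm_, while strategy=None, DDP, and FSDP2 (fully_shard) modules can be clipped through their regular parameters, which is why the removed _is_fsdp_module guard and the "not yet implemented" error for composable FSDP are no longer needed.

    # Illustrative sketch only; assumes FSDP refers to
    # torch.distributed.fsdp.FullyShardedDataParallel, as in the diff above.
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def clip_gradients(module: torch.nn.Module, clip_grad_norm: float) -> torch.Tensor:
        if isinstance(module, FSDP):
            # FSDP1 shards parameters behind the wrapper, so its sharding-aware
            # method must compute the global norm before clipping.
            return module.clip_grad_norm_(max_norm=clip_grad_norm)
        # strategy=None, DDP, and FSDP2 (fully_shard) leave module.parameters()
        # directly clippable, so the standard utility can be used instead.
        return torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=clip_grad_norm)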
