1 file changed: +6 −12 lines

@@ -879,19 +879,13 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         total_grad_norm = None
         # gradient norm clipping
         if clip_grad_norm:
-            if _is_fsdp_module(module):
-                if isinstance(module, FSDP):
-                    with get_timing_context(
-                        state, f"{self.__class__.__name__}.clip_grad_norm"
-                    ):
-                        total_grad_norm = module.clip_grad_norm_(
-                            max_norm=clip_grad_norm
-                        )
-                else:
-                    raise RuntimeError(
-                        "Composable FSDP clip_grad_norm is not yet implemented: https://github.com/pytorch/pytorch/issues/97271"
-                    )
+            if isinstance(module, FSDP):
+                with get_timing_context(
+                    state, f"{self.__class__.__name__}.clip_grad_norm"
+                ):
+                    total_grad_norm = module.clip_grad_norm_(max_norm=clip_grad_norm)
             else:
+                # strategy=None, DDP, and FSDP2 will work with this
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.clip_grad_norm"
                 ):
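
For reference, here is a minimal, self-contained sketch of the clipping logic after this change. The timing context is elided, `clip_gradients` is a hypothetical stand-in for the relevant part of `_update_weights`, and the else branch's call to `torch.nn.utils.clip_grad_norm_` is an assumption (the hunk ends before the else body), so treat it as illustrative rather than the exact implementation:

# Sketch only: a hypothetical free function standing in for the clipping
# portion of AutoUnit._update_weights; get_timing_context is omitted.
from typing import Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def clip_gradients(
    module: torch.nn.Module, clip_grad_norm: Optional[float]
) -> Optional[torch.Tensor]:
    total_grad_norm = None
    if clip_grad_norm:
        if isinstance(module, FSDP):
            # FSDP1 wrapper modules need the FSDP-aware method, which
            # reduces the gradient norm across shards before clipping.
            total_grad_norm = module.clip_grad_norm_(max_norm=clip_grad_norm)
        else:
            # strategy=None, DDP, and FSDP2 (fully_shard) all land here.
            # Assumed call: the standard utility works for these cases
            # because the parameters expose clippable gradients
            # (plain tensors, or DTensors in the FSDP2 case).
            total_grad_norm = torch.nn.utils.clip_grad_norm_(
                module.parameters(), max_norm=clip_grad_norm
            )
    return total_grad_norm

The effect of dropping the `_is_fsdp_module` check is that composable/fully_shard (FSDP2) modules no longer hit the RuntimeError referencing https://github.com/pytorch/pytorch/issues/97271; they simply fall through to the generic path, as the new inline comment notes.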