diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index b15b620ae1..c39e9be528 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -879,19 +879,13 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         total_grad_norm = None
         # gradient norm clipping
         if clip_grad_norm:
-            if _is_fsdp_module(module):
-                if isinstance(module, FSDP):
-                    with get_timing_context(
-                        state, f"{self.__class__.__name__}.clip_grad_norm"
-                    ):
-                        total_grad_norm = module.clip_grad_norm_(
-                            max_norm=clip_grad_norm
-                        )
-                else:
-                    raise RuntimeError(
-                        "Composable FSDP clip_grad_norm is not yet implemented: https://github.com/pytorch/pytorch/issues/97271"
-                    )
+            if isinstance(module, FSDP):
+                with get_timing_context(
+                    state, f"{self.__class__.__name__}.clip_grad_norm"
+                ):
+                    total_grad_norm = module.clip_grad_norm_(max_norm=clip_grad_norm)
             else:
+                # strategy=None, DDP, and FSDP2 will work with this
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.clip_grad_norm"
                 ):