@@ -494,18 +494,25 @@ def _clip_grad_norm_with_ep(
494494 else :
495495 non_ep_params .append (p )
496496 non_ep_grads .append (p .grad )
497+
498+ # Either list can be empty depending on the parallelization strategy:
499+ # - In torchtitan with separate dense/sparse meshes, both lists are typically non-empty
500+ # - In autoparallel, all params may live on a single sparse mesh with "ep" dimension,
501+ # so non_ep_grads would be empty
502+ # - In PP + EP setups, certain PP ranks may only own EP or non-EP layers
497503 ep_grads_total_norm = torch .nn .utils .get_total_norm (
498504 ep_grads , norm_type , error_if_nonfinite , foreach
499505 )
500- # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor
501- # This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance.
506+ # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
502507 if isinstance (ep_grads_total_norm , DTensor ):
503508 ep_grads_total_norm = ep_grads_total_norm .full_tensor ()
504509
505- # pyrefly: ignore [missing-attribute]
506510 non_ep_grads_total_norm = torch .nn .utils .get_total_norm (
507511 non_ep_grads , norm_type , error_if_nonfinite , foreach
508- ).full_tensor ()
512+ )
513+ # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
514+ if isinstance (non_ep_grads_total_norm , DTensor ):
515+ non_ep_grads_total_norm = non_ep_grads_total_norm .full_tensor ()
509516
510517 if math .isinf (norm_type ):
511518 total_norm = torch .maximum (ep_grads_total_norm , non_ep_grads_total_norm )
0 commit comments