Skip to content

Commit 902c95c

Browse files
Fix grad norm clipping for AutoP and dsv3 model init
1 parent 0a2107f commit 902c95c

File tree

2 files changed

+17
-5
lines changed
  • torchtitan

2 files changed

+17
-5
lines changed

torchtitan/distributed/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -494,18 +494,25 @@ def _clip_grad_norm_with_ep(
494494
else:
495495
non_ep_params.append(p)
496496
non_ep_grads.append(p.grad)
497+
498+
# Either list can be empty depending on the parallelization strategy:
499+
# - In torchtitan with separate dense/sparse meshes, both lists are typically non-empty
500+
# - In autoparallel, all params may live on a single sparse mesh with "ep" dimension,
501+
# so non_ep_grads would be empty
502+
# - In PP + EP setups, certain PP ranks may only own EP or non-EP layers
497503
ep_grads_total_norm = torch.nn.utils.get_total_norm(
498504
ep_grads, norm_type, error_if_nonfinite, foreach
499505
)
500-
# ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor
501-
# This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance.
506+
# get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
502507
if isinstance(ep_grads_total_norm, DTensor):
503508
ep_grads_total_norm = ep_grads_total_norm.full_tensor()
504509

505-
# pyrefly: ignore [missing-attribute]
506510
non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
507511
non_ep_grads, norm_type, error_if_nonfinite, foreach
508-
).full_tensor()
512+
)
513+
# get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
514+
if isinstance(non_ep_grads_total_norm, DTensor):
515+
non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor()
509516

510517
if math.isinf(norm_type):
511518
total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)

torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,9 @@
1515
# Need to share same base class with torchtitan models
1616
class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol):
1717
def __init__(self, model_args: DeepSeekV3ModelArgs):
18-
super().__init__(model_args)
18+
# Call _DeepSeekV3Model.__init__ which calls nn.Module.__init__
19+
# Note: We don't call ModelProtocol.__init__ separately because:
20+
# 1. nn.Module.__init__() is already called by _DeepSeekV3Model.__init__
21+
# 2. Calling ModelProtocol.__init__ after would reset all module state
22+
# (nn.Module.__init__ clears _modules, _parameters, etc.)
23+
_DeepSeekV3Model.__init__(self, model_args)

0 commit comments

Comments (0)