153 changes: 100 additions & 53 deletions torchtitan/experiments/dion_optimizer/muon.py
@@ -243,13 +243,19 @@ def _create_muon_tasks(
continue

# Wrap hyperparameters in tensors for torch.compile
lr = torch.tensor(group["lr"])
mu = torch.tensor(group["mu"])
weight_decay = torch.tensor(group["weight_decay"])
epsilon = torch.tensor(group["epsilon"])
nesterov = group["nesterov"]
flatten = group["flatten"]
adjust_lr = group["adjust_lr"]
muon_update_args = dict(
lr=torch.tensor(group["lr"]),
momentum=torch.tensor(group["mu"]),
weight_decay=torch.tensor(group["weight_decay"]),
epsilon=torch.tensor(group["epsilon"]),
nesterov=group["nesterov"],
flatten=group["flatten"],
adjust_lr=group["adjust_lr"],
device_rank=self._device_rank,
world_size=self._world_size,
process_group=self._process_group,
newton_schulz_func=self._newton_schulz_func,
)

# Create batches of parameters of size self._world_size
for params in create_param_batches(
@@ -259,9 +265,12 @@ def _create_muon_tasks(
states = [self._get_or_initialize_state(p, algo_name) for p in params]
momentums = [s["momentum"] for s in states]

# Get sharding dimension
# Get sharding state for DTensor
is_batch_sharded = False
is_matrix_sharded = False
sharded_mesh_dim = None
sharded_tensor_dim = None

if isinstance(params[0], DTensor):
if not isinstance(self._distributed_mesh, DeviceMesh):
raise RuntimeError(
@@ -275,7 +284,23 @@
for i, p in enumerate(params[0].placements)
if p.is_shard() and params[0].device_mesh.size(i) > 1
]

# If we don't flatten 3D matrices, we can ignore shard placements along batch dimensions
# Only keep placements that shard one of the two matrix dimensions
if not group["flatten"]:
matrix_dims = {params[0].ndim - 1, params[0].ndim - 2}
is_batch_sharded = any(
p.dim not in matrix_dims for _, p in shard_placements
)
shard_placements = [
(i, p) for i, p in shard_placements if p.dim in matrix_dims
]

# Check that we have no more than 1 sharded matrix dimension
# Note that non-flattened 3D tensors can have additional sharded batch dimensions
# Flattened 3D tensors are limited to one sharded dimension out of all dimensions
if len(shard_placements) == 1:
is_matrix_sharded = True
sharded_mesh_dim = shard_placements[0][0]
sharded_tensor_dim = shard_placements[0][1].dim
elif len(shard_placements) > 1:
@@ -290,28 +315,35 @@
!= self._process_group
):
raise RuntimeError(
f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh"
f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. "
f"DTensor has mesh: {params[0].device_mesh}, placements: {params[0].placements}, but optimizer was created with mesh: {self._distributed_mesh}."
)

yield AsyncTask(
muon_update_batch_async(
X=pad_batch(params, self._world_size),
G=pad_batch(gradients, self._world_size),
M=pad_batch(momentums, self._world_size),
lr=lr,
momentum=mu,
weight_decay=weight_decay,
epsilon=epsilon,
nesterov=nesterov,
flatten=flatten,
adjust_lr=adjust_lr,
device_rank=self._device_rank,
world_size=self._world_size,
shard_dim=sharded_tensor_dim,
process_group=self._process_group,
newton_schulz_func=self._newton_schulz_func,
# Special case for 3D tensors sharded along batch dimension
# As long as matrix dimensions are not sharded, each device will have whole matrices
# Each device already has different matrices of the batch, so we can't parallelize further
if is_batch_sharded and not is_matrix_sharded:
for x, g, m in zip(params, gradients, momentums):
yield AsyncTask(
muon_update_batch_async(
X=[x],
G=[g],
M=[m],
shard_dim=None, # No sharded matrix dim
**muon_update_args,
)
)
# Otherwise, we parallelize the Muon update across devices
else:
yield AsyncTask(
muon_update_batch_async(
X=pad_batch(params, self._world_size),
G=pad_batch(gradients, self._world_size),
M=pad_batch(momentums, self._world_size),
shard_dim=sharded_tensor_dim,
**muon_update_args,
)
)
)

def _create_lion_tasks(
self,
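A minimal, self-contained sketch of the placement classification performed in the hunk above, for readers skimming the diff. `classify_placements` and its `(mesh_dim, tensor_dim)` tuple inputs are illustrative stand-ins for the DTensor placement handling, not helpers that exist in this PR.

```python
# Hypothetical helper mirroring the batch-sharded vs. matrix-sharded decision above.
# A "placement" is modeled as (mesh_dim, shard_tensor_dim); replicated placements
# are simply omitted from the list.

def classify_placements(ndim, shard_placements, flatten):
    """Return (is_batch_sharded, is_matrix_sharded, sharded_mesh_dim, sharded_tensor_dim)."""
    matrix_dims = {ndim - 1, ndim - 2}

    if not flatten:
        # Shards along leading (batch) dimensions leave whole matrices on each device.
        is_batch_sharded = any(tdim not in matrix_dims for _, tdim in shard_placements)
        shard_placements = [(m, t) for m, t in shard_placements if t in matrix_dims]
    else:
        is_batch_sharded = False

    if len(shard_placements) > 1:
        raise RuntimeError("At most one sharded matrix dimension is supported")

    if shard_placements:
        mesh_dim, tensor_dim = shard_placements[0]
        return is_batch_sharded, True, mesh_dim, tensor_dim
    return is_batch_sharded, False, None, None


# Example: a (num_experts, d_out, d_in) expert weight sharded over mesh dim 0
# along the expert (batch) dimension, with flatten disabled.
print(classify_placements(ndim=3, shard_placements=[(0, 0)], flatten=False))
# -> (True, False, None, None): whole matrices per device, handled one per task.
```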
@@ -419,9 +451,6 @@ def muon_update_batch_async(

assert len(X) == len(G)
assert len(X) == len(M)
assert len(X) == world_size

# Expert parameter tracking (logging removed for cleaner output)

# Update momentum and compute the inputs for orthogonalization
U = muon_update_pre_orthogonalize(
@@ -435,6 +464,7 @@
if shard_dim is not None:
# Use all-to-all to transform from a batch of shards to a single whole matrix
# https://www.essential.ai/blog/infra
assert len(X) == world_size, "Batch size must equal world size"
assert (
process_group is not None
), "process_group must be provided for sharded DTensors"
Expand Down Expand Up @@ -477,9 +507,12 @@ def muon_update_batch_async(
yield
work.wait()

else:
# Matrices are not sharded, so we can directly orthogonalize
# Get a single matrix corresponding to this device
# Matrices are not sharded, so we can distribute the batch across different devices
# Get a single matrix of the batch corresponding to this device
elif len(U) > 1:
assert len(U) == world_size, "Batch size must equal world size"
assert process_group is not None

single_matrix = U[device_rank]
assert not isinstance(single_matrix, DTensor)

@@ -490,30 +523,36 @@
epsilon=epsilon,
)

if process_group is not None and process_group.size() > 1:
# Allocate empty tensors to receive updates from other devices
U = [torch.empty_like(u) for u in U]
# Allocate empty tensors to receive updates from other devices
U = [torch.empty_like(u) for u in U]

# All gather orthogonalized results from other devices into buffer
work = dist.all_gather(
U, single_matrix.contiguous(), group=process_group, async_op=True
)
yield
work.wait()
# All gather orthogonalized results from other devices into buffer
work = dist.all_gather(
U, single_matrix.contiguous(), group=process_group, async_op=True
)
yield
work.wait()

else:
# Single GPU case, no need to gather
assert world_size == 1
U = [single_matrix]
# Single matrix with no sharded dimension. This happens in 2 cases:
# - Running on a single GPU
# - 3D+ tensors sharded along a batch dimension (whole matrices per device)
else:
assert len(U) == 1
U[0] = muon_update_newton_schulz(
U[0],
newton_schulz_func=newton_schulz_func,
flatten=flatten,
epsilon=epsilon,
)

# Compute scaled learning rate
# Do this before to_local(X) because we use the full tensor shape, not the shard shape
if adjust_lr is None:
adjusted_lr = lr
elif adjust_lr == "spectral_norm":
adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape)
adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten)
elif adjust_lr == "rms_norm":
adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape)
adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten)
else:
raise ValueError(f"Unknown adjust_lr value: {adjust_lr}")
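The `async_op=True` / `yield` / `work.wait()` pattern in the hunks above is what lets the surrounding task runner overlap one batch's communication with another batch's computation. A toy, single-process illustration of the idea follows; `fake_all_gather`, the thread pool, and the driver loop are all stand-ins (the real code uses `torch.distributed` work handles, not futures).

```python
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor()

def fake_all_gather(batch_id):
    time.sleep(0.1)          # stands in for the network time of dist.all_gather
    return f"gathered-{batch_id}"

def update_task(batch_id):
    # ... local math for this batch would go here ...
    work = pool.submit(fake_all_gather, batch_id)   # like async_op=True
    yield                                           # hand control back while comm runs
    result = work.result()                          # like work.wait()
    print(f"batch {batch_id}: {result}")

# A minimal round-robin driver: start every task, then resume each one,
# so all of the fake collectives are in flight at the same time.
tasks = [update_task(i) for i in range(4)]
for t in tasks:
    next(t)              # run until the task parks on its collective
for t in tasks:
    try:
        next(t)          # resume after the collective completes
    except StopIteration:
        pass
pool.shutdown()
```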

@@ -634,19 +673,27 @@ def muon_update_newton_schulz(
return newton_schulz_func(X, epsilon=epsilon).reshape(original_shape)


def adjust_lr_rms_norm(lr, param_shape):
def adjust_lr_rms_norm(lr, param_shape, flatten):
# Adjust learning rate for constant element-wise RMS norm
# https://arxiv.org/abs/2502.16982
A, B = param_shape[:2]
adjusted_ratio = 0.2 * math.sqrt(max(A, B))
if flatten:
fan_out = param_shape[0]
fan_in = math.prod(param_shape[1:])
else:
fan_out, fan_in = param_shape[-2:]
adjusted_ratio = 0.2 * math.sqrt(max(fan_out, fan_in))
adjusted_lr = lr * adjusted_ratio
return adjusted_lr


def adjust_lr_spectral_norm(lr, param_shape):
def adjust_lr_spectral_norm(lr, param_shape, flatten):
# Adjust from spectral norm 1 to RMS operator norm 1
# https://arxiv.org/abs/2310.17813
fan_out, fan_in = param_shape[:2]
if flatten:
fan_out = param_shape[0]
fan_in = math.prod(param_shape[1:])
else:
fan_out, fan_in = param_shape[-2:]
adjusted_lr = lr * math.sqrt(fan_out / fan_in)
return adjusted_lr
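To make the two adjustment rules concrete, here is a quick numeric check of the formulas exactly as written above. The shapes and the 0.01 base learning rate are arbitrary examples, not values from the PR.

```python
import math

def adjust_lr_rms_norm(lr, param_shape, flatten):
    if flatten:
        fan_out, fan_in = param_shape[0], math.prod(param_shape[1:])
    else:
        fan_out, fan_in = param_shape[-2:]
    return lr * 0.2 * math.sqrt(max(fan_out, fan_in))

def adjust_lr_spectral_norm(lr, param_shape, flatten):
    if flatten:
        fan_out, fan_in = param_shape[0], math.prod(param_shape[1:])
    else:
        fan_out, fan_in = param_shape[-2:]
    return lr * math.sqrt(fan_out / fan_in)

# 2D linear weight (out=1024, in=4096): flatten has no effect on the result.
print(adjust_lr_rms_norm(0.01, (1024, 4096), flatten=False))       # 0.01 * 0.2 * 64 = 0.128
print(adjust_lr_spectral_norm(0.01, (1024, 4096), flatten=False))  # 0.01 * 0.5      = 0.005

# 3D expert weight (experts=8, out=1024, in=4096): flatten=False scales per matrix,
# flatten=True treats it as one (8 x 4_194_304) matrix.
print(adjust_lr_rms_norm(0.01, (8, 1024, 4096), flatten=False))    # 0.128
print(adjust_lr_rms_norm(0.01, (8, 1024, 4096), flatten=True))     # 0.01 * 0.2 * 2048 = 4.096
```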

65 changes: 40 additions & 25 deletions torchtitan/experiments/dion_optimizer/parameter_classification.py
@@ -13,6 +13,7 @@

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor

from torchtitan.tools.logging import logger

@@ -41,6 +42,7 @@ def create_parameter_groups(
"routing": [],
"expert": [],
}
dtensor_info = {}

for model in model_parts:
# Group parameters by type
@@ -60,25 +62,32 @@
# Add parameter name as attribute for debug logging
param._param_name = name

# Log parameter placement
if isinstance(param, DTensor):
dtensor_info[name] = (
param.device_mesh,
param.placements,
param.to_local().shape,
)

# Check if this is an expert weight parameter
if is_expert_param(name, param, model):
expert_type = classify_expert_param(name, param)
param_stats["expert"].append((name, param.shape, expert_type))

# Expert weights can use either DION (for 2D matrices only) or a dedicated expert optimizer
expert_optimizer = getattr(dion_config, "expert_optimizer", None)

if expert_optimizer is not None:
# Use dedicated expert optimizer if specified
expert_params.append(param)
param_stats["expert"].append(
(name, param.shape, expert_type, expert_optimizer)
)
elif param.ndim >= 2 and not is_head_param(name, param, model):
# Use matrix algorithm for expert matrices
# Dion only supports 2D, but Muon supports 3D+ (with flattening)
algorithm_name = dion_config.algorithm.upper()

# Check if flatten is enabled (for Muon 3D+ tensor support)
flatten_enabled = getattr(dion_config, "flatten", False)

# Handle case-insensitive algorithm matching
is_muon = (
algorithm_name in ["MUON", "muon"]
@@ -89,12 +98,22 @@
if is_muon and param.ndim >= 2:
# Use Muon for 2D+ tensors when algorithm is MUON, or always for 2D tensors
dion_params.append(param)
param_stats["expert"].append(
(name, param.shape, expert_type, "muon")
)
else:
# Dion doesn't support 3D+, fall back to scalar optimizer
scalar_params.append(param)
param_stats["expert"].append(
(name, param.shape, expert_type, "scalar")
)
else:
# Fall back to scalar optimizer for 1D expert parameters
scalar_params.append(param)
param_stats["expert"].append(
(name, param.shape, expert_type, "scalar")
)

continue

# Classify parameter based on shape and module type
@@ -156,6 +175,8 @@
logger.info(f" - {name}: {shape}")

logger.info(f"Scalar parameters ({scalar_opt}): {len(param_stats['scalar'])}")
for name, shape in param_stats["scalar"]:
logger.info(f" - {name}: {shape}")

logger.info(
f"Embedding parameters ({embedding_opt}): {len(param_stats['embedding'])}"
@@ -184,34 +205,26 @@
)
if expert_optimizer is not None:
logger.info(f"Expert optimizer configured: {expert_optimizer.upper()}")
for name, shape, expert_type in param_stats["expert"]:
logger.info(
f" ✓ EXPERT: {name} ({shape}) - {expert_type} → USING {expert_optimizer.upper()}"
)
else:
logger.info(
"Expert optimizer not configured - using default classification:"
)
for name, shape, expert_type in param_stats["expert"]:
# Check if this expert parameter actually uses the matrix algorithm
algorithm_name = dion_config.algorithm.upper()
is_muon = (
algorithm_name in ["MUON", "muon"]
or dion_config.algorithm.lower() == "muon"
)
flatten_enabled = getattr(dion_config, "flatten", False)

if (len(shape) == 2) or (is_muon and len(shape) >= 2):
logger.info(
f" ✓ EXPERT: {name} ({shape}) - {expert_type} → USING {algorithm_name}"
)
else:
logger.info(
f" ✓ EXPERT: {name} ({shape}) - {expert_type} → USING {scalar_opt}"
)
for name, shape, expert_type, expert_opt in param_stats["expert"]:
logger.info(
f" ✓ EXPERT: {name} ({shape}) - {expert_type} → USING {expert_opt.upper()}"
)
else:
logger.info("No expert weight parameters detected in this model")

logger.info("=" * 40)
logger.info("DTENSOR INFO")
logger.info("=" * 40)

for name, (device_mesh, placements, local_shape) in dtensor_info.items():
logger.info(
f"{name:>40}: device mesh={device_mesh}, placements={placements}, local shape={local_shape}"
)

logger.info("=" * 80)

return param_groups
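As a reading aid, the branching above that decides which optimizer bucket an expert weight lands in can be condensed roughly as below. `route_expert_param` is a hypothetical paraphrase for illustration only (the config object is simplified and `is_head` stands in for `is_head_param`); it is not a helper that exists in the repository.

```python
def route_expert_param(ndim, algorithm, expert_optimizer=None, is_head=False):
    """Return which optimizer bucket an expert weight would fall into."""
    if expert_optimizer is not None:
        return expert_optimizer          # dedicated expert optimizer wins
    if ndim >= 2 and not is_head:
        if algorithm.lower() == "muon":
            return "muon"                # Muon handles 2D and 3D+ (with flattening)
        return "scalar"                  # 3D+ expert weights fall back off the matrix path
    return "scalar"                      # 1D expert params use the scalar optimizer

assert route_expert_param(3, "muon") == "muon"
assert route_expert_param(3, "dion") == "scalar"
assert route_expert_param(1, "muon") == "scalar"
assert route_expert_param(3, "dion", expert_optimizer="adamw") == "adamw"
```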
@@ -287,7 +300,9 @@ def is_expert_param(name: str, param: torch.Tensor, model: nn.Module) -> bool:
".expert.", # Alternative expert pattern
"expert_", # Expert prefix
"moe.expert", # MoE expert pattern
"shared_expert", # DeepSeek shared experts
"shared_experts", # DeepSeek shared experts
"routed_expert", # DeepSeek routed experts
"routed_experts", # DeepSeek routed experts
".experts[", # Indexed expert pattern
".w1.", # Expert feed-forward weights (common in MoE)