From 0586ffd729ec0e5bd354f0bd5134477690dbf619 Mon Sep 17 00:00:00 2001 From: wang55 Date: Mon, 25 Aug 2025 03:26:03 +0200 Subject: [PATCH] init scion --- torchtitan/components/optimizer.py | 16 +- torchtitan/config/job_config.py | 30 + torchtitan/distributed/utils.py | 10 +- .../experiments/distributed_scion/__init__.py | 12 + .../clean_distributed_scion.py | 842 +++++++++++ .../distributed_scion/distributed_scion.py | 1254 +++++++++++++++++ .../distributed_scion/muon_utils.py | 119 ++ .../distributed_scion/naive_param_norm.py | 238 ++++ .../distributed_scion/norm_helper.py | 217 +++ .../train_configs/debug_model.toml | 109 ++ .../experiments/distributed_scion/utils.py | 206 +++ 11 files changed, 3049 insertions(+), 4 deletions(-) create mode 100644 torchtitan/experiments/distributed_scion/__init__.py create mode 100644 torchtitan/experiments/distributed_scion/clean_distributed_scion.py create mode 100644 torchtitan/experiments/distributed_scion/distributed_scion.py create mode 100644 torchtitan/experiments/distributed_scion/muon_utils.py create mode 100644 torchtitan/experiments/distributed_scion/naive_param_norm.py create mode 100644 torchtitan/experiments/distributed_scion/norm_helper.py create mode 100644 torchtitan/experiments/distributed_scion/train_configs/debug_model.toml create mode 100644 torchtitan/experiments/distributed_scion/utils.py diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d3e962810..568620ef4 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -21,6 +21,7 @@ from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config import Optimizer as OptimizerConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments import distributed_scion __all__ = [ "OptimizersContainer", @@ -75,7 +76,12 @@ def __init__( self.optimizers = [] self.model_parts = model_parts for model in self.model_parts: - params = [p for p in model.parameters() if p.requires_grad] + if issubclass(optimizer_cls, (distributed_scion.DistributedScion)): + params, optimizer_kwargs = distributed_scion.create_scion_param_groups( + model, optimizer_kwargs + ) + else: + params = [p for p in model.parameters() if p.requires_grad] self.optimizers.append(optimizer_cls(params, **optimizer_kwargs)) all_params.extend(params) self._validate_length(len(self.model_parts)) @@ -302,9 +308,17 @@ def build_optimizers( "foreach": foreach, } + if name in ["DistributedScion"]: + optimizer_kwargs = ( + distributed_scion.create_scion_optimizer_kwargs_from_optimizer_config( + optimizer_config, parallel_dims + ) + ) + optimizer_classes = { "Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW, + "DistributedScion": distributed_scion.DistributedScion, } if name not in optimizer_classes: raise NotImplementedError(f"Optimizer {name} not added.") diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index a2247aa21..0d102a40b 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -69,6 +69,9 @@ class Metrics: enable_wandb: bool = False """Whether to log metrics to Weights & Biases""" + log_norm_freq: int = -1 + """How often to log norms in iterations""" + @dataclass class Model: @@ -122,6 +125,33 @@ class Optimizer: weight_decay: float = 0.1 """Weight decay to use""" + mup_width_multiplier: float = 1.0 + """ + Width multiplier for the model to apply μP scaling (only used + for Adam/Muon-based optimizers). 
+ """ + + is_light: bool = False + """Whether to use Scion's light (memory-saving) version""" + + norm_factor: str = "spectral" + """Which norm factor to use""" + + zeropower_backend: str = "newtonschulz5" + "Which `zeropower_backend` to use." + + backend_steps: int = 5 + """Number of steps for the Scion backend""" + + momentum: float = 0.95 + """Scion momentum to use""" + + nesterov: bool = False + """Whether to use Nesterov momentum in Scion""" + + extra_splits_rules: list[dict[str, Any]] | None = None + """Extra parameter group splitting rules for Scion optimizers""" + implementation: Literal["for-loop", "foreach", "fused"] = "fused" """ Specify which optimizer implementation to use: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 74d310dfc..4565f5e72 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -388,7 +388,8 @@ def clip_grad_norm_( dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) total_norm **= 1.0 / norm_type - torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + if max_norm > 0: + torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) return total_norm @@ -444,7 +445,10 @@ def _clip_grad_norm_with_ep( dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) total_norm **= 1.0 / norm_type - torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach) - torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach) + if max_norm > 0: + torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach) + torch.nn.utils.clip_grads_with_norm_( + non_ep_params, max_norm, total_norm, foreach + ) return total_norm diff --git a/torchtitan/experiments/distributed_scion/__init__.py b/torchtitan/experiments/distributed_scion/__init__.py new file mode 100644 index 000000000..ab2849620 --- /dev/null +++ b/torchtitan/experiments/distributed_scion/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .distributed_scion import DistributedScion # noqa: F401 +from .utils import ( # noqa: F401 # noqa: F401 # noqa: F401 + create_scion_optimizer_kwargs_from_optimizer_config, + create_scion_param_groups, + remove_orig_mod_and_weight_for_p_name, +) diff --git a/torchtitan/experiments/distributed_scion/clean_distributed_scion.py b/torchtitan/experiments/distributed_scion/clean_distributed_scion.py new file mode 100644 index 000000000..526a1c4e8 --- /dev/null +++ b/torchtitan/experiments/distributed_scion/clean_distributed_scion.py @@ -0,0 +1,842 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from enum import Enum +from functools import partial + +import torch +import torch.distributed as dist +import torch.distributed.tensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard + +from torchtitan.tools.logging import logger + +from .muon_utils import zeropower_backends +from .norm_helper import NORM_FUNCTIONS + +__all__ = [ + "DistributedScion", +] + + +class ParamType(Enum): + DDP = 0 + FSDP = 1 + Expert = 2 + Unknown = 3 + + +def get_param_type(p, fsdp_enabled, expert_enabled): + """ + We can aggressively assume that the param is FSDP-Sharded + """ + if p.grad is None: + return ParamType.Unknown + if not fsdp_enabled and not expert_enabled and isinstance(p, torch.Tensor): + return ParamType.DDP + if p.ndim == 3: + return ParamType.Expert + elif fsdp_enabled: + return ParamType.FSDP + else: + return ParamType.Unknown + + +def tp_axis(placements: tuple, tp_enabled: bool = False) -> int | None: + """ + Return the index in `placements` that belongs to *tensor-parallel* (TP). + + Heuristics (PyTorch-TP default layouts): + 1. Row-parallel weights ⇒ `_StridedShard` ⟶ that axis is TP. + 2. Col-parallel weights ⇒ `Shard(dim != 0)` ⟶ that axis is TP + (FSDP shards dim-0, so a non-zero dim means TP). + """ + # rule 1 – row-parallel + for i, p in enumerate(placements): + if isinstance(p, _StridedShard): + return i + + # rule 2 – col-parallel + for i, p in enumerate(placements): + if isinstance(p, Shard) and p.dim != 0: + return i + + # this is a special case, We do TP only + if tp_enabled and len(placements) == 1: + if isinstance(placements[0], Shard): + return 0 + return None # could not infer + + +def gather_tp_shard(tensor, tp_group, tp_world_size, original_placements): + # TP is used, we need to gather the TP-shard params first + tp_mesh_dim = tp_axis(original_placements, True) + assert tp_mesh_dim is not None, "something wrong here" + shard_dim = original_placements[tp_mesh_dim].dim + + output_tensors = [torch.empty_like(tensor) for _ in range(tp_world_size)] + dist.all_gather(output_tensors, tensor, group=tp_group) + return torch.cat(output_tensors, dim=shard_dim) + + +def calculate_shard_shape(shape, rank, world_size): + full = shape[0] + splits = torch.arange(full).chunk(world_size) + if rank >= len(splits): + dim0 = 0 + else: + dim0 = len(splits[rank]) + + return (dim0, *shape[1:]) + + +class DistributedScion(torch.optim.Optimizer): + def __init__( + self, + params, + is_light, + weight_decay, + lr, + momentum, + nesterov, + eps, + norm_factor, + backend, + backend_steps, + parallel_dims, + communication_dtype=torch.bfloat16, + extra_reduce_for_HSDP=False, + experts_weights_layout="G-D_out-D_in", + ): + self.need_to_calculate_norm = False + self.norms_to_log: list[str] = list(NORM_FUNCTIONS.keys()) + self.norms_at_current_step = {} + self.extra_reduce_for_HSDP = False + self.log_parameters_types = True + + defaults = dict( + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov=nesterov, + eps=eps, + norm_factor=norm_factor, + backend=backend, + backend_steps=backend_steps, + ) + self.is_light = is_light + + is_unconstrained = weight_decay == 0 + + self.world_mesh = parallel_dims.world_mesh + + self.fsdp_enabled = parallel_dims.fsdp_enabled + self.expert_enabled = parallel_dims.ep_enabled + self.dp_replicate_enabled = parallel_dims.dp_replicate_enabled + self.tp_enabled = parallel_dims.tp_enabled + + # this is used to ensure only the DP or FSDP rank 0 will have norms + 
self.is_dp_rank_0 = dist.get_rank(self.world_mesh["dp_cp"].get_group()) == 0 + + assert experts_weights_layout in [ + "G-D_in-D_out", + "G-D_out-D_in", + ], f"Unknown experts weights layout: {experts_weights_layout}" + self.experts_need_transpose = experts_weights_layout == "G-D_in-D_out" + self.extra_reduce_for_HSDP = extra_reduce_for_HSDP + + logger.info( + f"Distributed Scion optimizer " + f"(is_light={self.is_light}, is_unconstrained={is_unconstrained}) " + f"is enabled with world_mesh={self.world_mesh} | fsdp_enabled={self.fsdp_enabled} | " + f"EP={self.expert_enabled} | TP={self.tp_enabled} | DP={self.dp_replicate_enabled}" + ) + + super().__init__(params, defaults) + if self.is_light: + # Initialize state + self._store_grads_in_state() + # Do not pass `self` through syntactic sugar. We need the + # argument to not be populated. + self.register_state_dict_pre_hook( + type(self)._store_grads_in_state, + ) + self.register_load_state_dict_post_hook( + type(self)._load_grads_from_state, + ) + + self.communication_dtype = communication_dtype + self.groups_info = {} + self.parameters_to_groups = {} + for group_idx, group in enumerate(self.param_groups): + lr = group["lr"] + nesterov = group["nesterov"] + momentum = group["momentum"] + wd = group["weight_decay"] + param_kwargs = { + "eps": group["eps"], + "norm_factor": group["norm_factor"], + "zeropower_backend": group["backend"], + "backend_steps": group["backend_steps"], + } + self.groups_info[group_idx] = [lr, nesterov, momentum, wd, param_kwargs] + for param in group["params"]: + self.parameters_to_groups[id(param)] = group_idx + + if self.is_light and nesterov: + raise RuntimeError( + "Nesterov momentum is not supported for Scion's light mode. " + "Please set nesterov=False." + ) + + def calculate_norm_at_next_step(self, norms_to_log: list[str]): + self.need_to_calculate_norm = True + self.norms_to_log = norms_to_log + self.norms_at_current_step = {} + + def get_norms_at_current_step(self): + if self.is_dp_rank_0: + return self.norms_at_current_step + else: + return {} + + def zero_grad(self, *args, **kwargs): + if self.is_light: + pass + else: + super().zero_grad(*args, **kwargs) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + scale_params, scale_param_names = [], [] + embed_params, embed_param_names = [], [] + ddp_params, ddp_param_names = [], [] + fsdp_params, fsdp_param_names = [], [] + expert_params, expert_param_names = [], [] + + for group_idx, group in enumerate(self.param_groups): + # we should update self.groups_info here incase we have LR and momentum scheduler + # We can also optionally do norm_factor and backend scheduler if we want to + lr = group["lr"] + nesterov = group["nesterov"] + momentum = group["momentum"] + wd = group["weight_decay"] + param_kwargs = { + "eps": group["eps"], + "norm_factor": group["norm_factor"], + "zeropower_backend": group["backend"], + "backend_steps": group["backend_steps"], + } + self.groups_info[group_idx] = [lr, nesterov, momentum, wd, param_kwargs] + + for p_name, p in zip(group["param_names"], group["params"]): + norm_factor = group["norm_factor"] + backend = group["backend"] + is_embed_norm = norm_factor.startswith( + "embed" + ) or norm_factor.startswith("unembed") + + if p.numel() == 1: + assert ( + backend == "identity" + ), "scale params must use identity backend" + assert ( + norm_factor == "sign" + ), "scale params must use sign norm factor" + scale_params.append(p) + 
scale_param_names.append(p_name) + continue + + if backend == "identity" and is_embed_norm: + # for these Row/Col-wise norm, there is no need to gather the gradient + embed_params.append(p) + embed_param_names.append(p_name) + continue + + param_type = get_param_type(p, self.fsdp_enabled, self.expert_enabled) + if param_type == ParamType.DDP: + ddp_params.append(p) + ddp_param_names.append(p_name) + elif param_type == ParamType.FSDP: + fsdp_params.append(p) + fsdp_param_names.append(p_name) + elif param_type == ParamType.Expert: + expert_params.append(p) + expert_param_names.append(p_name) + elif param_type == ParamType.Unknown: + logger.warning( + f"Unknown param type: {p_name}, p.shape {p.shape}, grad is None[?] " + f"{p.grad is None}, the optimizer will skip this param" + ) + # raise ValueError(f"Unknown param type: {p_name}") + continue + else: + raise ValueError("param_type") + + # Sort fsdp_params and their names together + fsdp_pairs = list(zip(fsdp_params, fsdp_param_names)) + fsdp_pairs.sort(key=lambda x: x[0].numel(), reverse=True) + fsdp_params, fsdp_param_names = zip(*fsdp_pairs) if fsdp_pairs else ([], []) + # Sort expert_params and their names together + expert_pairs = list(zip(expert_params, expert_param_names)) + expert_pairs.sort(key=lambda x: (x[0].numel(), x[0].shape[1]), reverse=True) + expert_params, expert_param_names = ( + zip(*expert_pairs) if expert_pairs else ([], []) + ) + if self.log_parameters_types: + # only log once + logger.info( + f"fsdp_params: {len(fsdp_params)} | expert_params: {len(expert_params)} | " + f"ddp_params: {len(ddp_params)} | embed_params: {len(embed_params)} | " + f"scale_params: {len(scale_params)}" + ) + self.log_parameters_types = False + + """ + We could merge `embed_params` and `expert_params` into one list. + The diff is, we are sure expert_params have bunch of 2D full-matrixs + But we might need to gather the `embed_params` to 2D full-matrixs + if we wanna to get the norm of the gradient. 
+ """ + self.step_scalar(scale_params, scale_param_names) + self.step_embedding(embed_params, embed_param_names) + self.step_experts(expert_params, expert_param_names) + self.step_ddp(ddp_params, ddp_param_names) + self.step_fsdp(fsdp_params, fsdp_param_names) + + # reset the flag for the next step + self.need_to_calculate_norm = False + return loss + + @torch.no_grad() + def lmo( + self, + g, + eps, + norm_factor, + zeropower_backend, + backend_steps, + ): + g = g.to_local() if isinstance(g, DTensor) else g + + # NB: make sure this function does not modify the grad inplace + # since it is also called during the log of gradients + def _lmo_for_2d_tensor(g, need_transpose=False): + g = g if not need_transpose else g.transpose(0, 1) + g = zeropower_backends[zeropower_backend](g, steps=backend_steps, eps=eps) + g = self.normalise_grad(g, norm_factor=norm_factor, eps=eps) + return g if not need_transpose else g.transpose(0, 1) + + if g.ndim == 2: + g = _lmo_for_2d_tensor(g, need_transpose=False) + elif g.ndim == 3: + if g.shape[0] > 0: + # When world_size [fsdp x EP] > Total number of experts, + # some ranks may have 0 experts that shape will be [0, d-in, d-out] + # We should return the original grad here and **do not** do stack + g = torch.stack( + [ + _lmo_for_2d_tensor( + g[i], need_transpose=self.experts_need_transpose + ) + for i in range(g.shape[0]) + ], + dim=0, + ) + else: + pass + elif g.ndim == 1: + if zeropower_backend != "identity": + g_diag = torch.diag_embed(g).contiguous() + result_diag = _lmo_for_2d_tensor(g_diag) + g = result_diag.diagonal().contiguous() + else: + g = _lmo_for_2d_tensor(g) + + # TODO(JSC): JUST HARD CODE IT TO USE 'identity' backend and 'bias_rms' norm_factor for + # now until we add regex to extra the norm's weights + # zeropower_backend = "identity" + # norm_factor = "bias_rms" + # g = _lmo_for_2d_tensor(g) + + else: + raise ValueError(f"Unknown grad shape: {g.shape}") + + return g + + @torch.no_grad() + def normalise_grad(self, g, norm_factor, eps): + if norm_factor == "spectral": + g = g * (g.size(0) / g.size(1)) ** 0.5 + elif norm_factor == "image_spectral": + g = g * max((g.size(0) / g.size(1)) ** 0.5, 1) + elif norm_factor.startswith("embed"): + # NB: here assume shape [vocab_size, embed_dim] + rms_values = torch.sqrt(g.pow(2).sum(axis=1, keepdim=True)) + g = g / (rms_values + eps) + if norm_factor == "embed_linear": + g = g * g.size(1) + elif norm_factor == "embed_sqrt": + g = g * g.size(1) ** 0.5 + else: + raise ValueError(f"Unknown norm_factor: {norm_factor}") + elif norm_factor.startswith("unembed"): + rms_values = torch.sqrt(g.pow(2).sum(axis=1, keepdim=True)) + g = g / (rms_values + eps) + if norm_factor == "unembed_linear": + g = g / g.size(1) + elif norm_factor == "unembed_sqrt": + g = g / g.size(1) ** 0.5 + else: + raise ValueError(f"Unknown norm_factor: {norm_factor}") + elif norm_factor == "sign": + g = torch.sign(g) + elif norm_factor == "bias_rms": + rms_value = torch.sqrt(g.pow(2).mean()) + g = g / (rms_value + eps) + elif norm_factor == "none": + pass + else: + raise ValueError(f"Unknown norm_factor: {norm_factor}") + + return g + + def __getstate__(self): + self._store_grads_in_state() + return super().__getstate__() + + def __setstate__(self, state): + super().__setstate__(state) + self._load_grads_from_state() + + def _store_grads_in_state(self): + for group in self.param_groups: + for param in group["params"]: + if isinstance(param, torch.Tensor) and param.grad is not None: + self.state.setdefault(param, {})["grad_state"] = param.grad 
+ + def _load_grads_from_state(self): + for param, state in self.state.items(): + if "grad_state" in state: + param.grad = state["grad_state"] + elif isinstance(param, torch.Tensor): + param.grad = None + + def update_bucket_params(self, params, updates, start_idx, end_idx, tp_group=None): + # TODO(JSC): we could maybe use tesnor update rather than for-loop here + # can be helpful for FSDP and EP params + for idx_in_bucket in range(start_idx, end_idx): + shift = idx_in_bucket - start_idx + p = params[idx_in_bucket] + u = updates[shift] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + + if wd != 0: + # p.data.mul_(1 - wd*lr) + p.mul_(1 - wd * lr) + + if isinstance(p, DTensor) and self.tp_enabled: + original_placements = p.placements + tp_mesh_dim = tp_axis(original_placements, p.shape == u.shape) + + if isinstance(p, DTensor): + if tp_group is None or tp_mesh_dim is None: + p.to_local().add_(u, alpha=-lr) + else: + tp_rank = tp_group.rank() + tp_sharded_dim = original_placements[tp_mesh_dim].dim + chunk_size = p.to_local().shape[tp_sharded_dim] + start_offset = tp_rank * chunk_size + + slicer = [slice(None)] * u.dim() + slicer[tp_sharded_dim] = slice( + start_offset, start_offset + chunk_size + ) + u_sliced = u[slicer] + p.to_local().add_(u_sliced, alpha=-lr) + else: + p.add_(u, alpha=-lr) + + if momentum != 1 and self.is_light and p.grad is not None: + p.grad.mul_(1 - momentum) + + def step_scalar( + self, + scalar_params, + scalar_param_names, + skip_update=False, + ): + if len(scalar_params) == 0: + return {} + + for param_idx in range(len(scalar_params)): + p = scalar_params[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + g = g.to_local() if isinstance(g, DTensor) else g + + # the lmo of scalar is just sign + u = torch.sign(g) + + if not skip_update: + self.update_bucket_params([p], [u], 0, 1) + + def step_embedding( + self, + embed_params, + embed_param_names, + skip_update=False, + ): + if len(embed_params) == 0: + return {} + + tp_group = None + # if self.dp_replicate_enabled: + # dp_replicate_group = self.world_mesh["dp_replicate"].get_group() + # else: + # dp_replicate_group = None + + if self.tp_enabled: + tp_group = self.world_mesh["tp"].get_group() + + for param_idx in range(len(embed_params)): + p = embed_params[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + + u = self.lmo(g, **param_kwargs) + + ######################################################### + # # As of we use norm for Embedding, maybe we should not do Reduce here + # if ( + # dp_replicate_group is not None + # and self.extra_reduce_for_HSDP + # and self.fsdp_enabled + # ): + # dist.all_reduce(u, group=dp_replicate_group, op=dist.ReduceOp.AVG) + # dist.barrier(group=dp_replicate_group) + if not skip_update: + self.update_bucket_params([p], [u], 0, 1, tp_group=tp_group) + + def step_experts( + self, + expert_params, + expert_param_names, + skip_update=False, + ): + if len(expert_params) == 0: + return {} + + device = expert_params[0].device + fsdp_group = self.world_mesh["dp_shard_cp"].get_group() + world_size = dist.get_world_size(fsdp_group) + local_rank = dist.get_rank(fsdp_group) + ep_per_rank = 
math.ceil(expert_params[0].shape[0] / world_size) + + # each rank will process `len(expert_params) * ep_per_rank` experts + # each expert will have `self.norms_to_log` norms + # so each rank will have `len(expert_params) * ep_per_rank * len(self.norms_to_log)` + # norms + # globally, its [[g0-ep0, g0-ep1, g0-ep2, ...], [g1-ep0, g1-ep1, g1-ep2, ...], ...] on each + # rank + + for param_idx in range(len(expert_params)): + p = expert_params[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + u = self.lmo(g, **param_kwargs) + + if not skip_update: + self.update_bucket_params([p], [u], 0, 1) + + def step_ddp( + self, + ddp_params, + ddp_param_names, + skip_update=False, + ): + # Either we do DDP + # or we do TP, there is no [DDP + TP] case but for safety we add sevel checks + # if len(ddp_params) == 0: + # return {} + + tp_group, dp_replicate_group = None, None + + rank = 0 + bucket_size = world_size = 1 + total_buckets = len(ddp_params) + + if self.dp_replicate_enabled: + dp_replicate_group = self.world_mesh["dp_replicate"].get_group() + world_size = dp_replicate_group.size() + rank = dp_replicate_group.rank() + + bucket_size = world_size + total_buckets = math.ceil(len(ddp_params) / bucket_size) + + if self.tp_enabled: + tp_group = self.world_mesh["tp"].get_group() + tp_world_size = dist.get_world_size(group=tp_group) + + device = ddp_params[0].device if len(ddp_params) > 0 else torch.device("cuda") + cast_dtype = self.communication_dtype + zero_tensor = partial(torch.zeros, dtype=cast_dtype, device=device) + + # for DDP, we need to first update the buffer + for param_idx in range(len(ddp_params)): + p = ddp_params[param_idx] + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + + # then we do scion stuff + for bucket_idx in range(total_buckets): + start_idx = bucket_idx * bucket_size + end_idx = min(start_idx + bucket_size, len(ddp_params)) + current_rank_idx = start_idx + rank + if current_rank_idx < len(ddp_params): + p = ddp_params[current_rank_idx] + # Step 1: Get the gradient + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=False, gather_to_local=False + ) + if isinstance(g, DTensor) and self.tp_enabled: + g = gather_tp_shard( + g.to_local(), tp_group, tp_world_size, g.placements + ).to(dtype=cast_dtype) + + else: + # To avoid idle stream, we pad the last rank + p = ddp_params[end_idx - 1] + g = zero_tensor(p.shape) + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + + # step 2: lmo + u = self.lmo(g, **param_kwargs) + + if not skip_update: + # Step 3: FOR DDP, we do all-gather + if self.dp_replicate_enabled: + # only gather params when we doing DDP + BUCKET + gather_lists = [None] * world_size + for i in range(world_size): + param_idx = start_idx + i + if i == rank or param_idx >= len(ddp_params): + gather_lists[i] = u.to(dtype=cast_dtype) + elif param_idx < len(ddp_params): + p = ddp_params[start_idx + i] + gather_lists[i] = zero_tensor(p.shape) + dist.all_gather( + gather_lists, u.to(dtype=cast_dtype), group=dp_replicate_group + ) + if self.tp_enabled: + # only if DP+TP we need to barrier here 
other-wise its automatically synced + dist.barrier(group=dp_replicate_group) + else: + # other wise (TP only), dp world_size is 1 + gather_lists = [u.to(dtype=cast_dtype)] + + # Step 4: Update the parameters + self.update_bucket_params( + ddp_params, gather_lists, start_idx, end_idx, tp_group=tp_group + ) + + def step_fsdp( + self, + fsdp_params, + fsdp_param_names, + skip_update=False, + ): + if len(fsdp_params) == 0: + return {} + tp_group, dp_replicate_group = None, None + """ + To make FSDP+DP works, we lets step_fsdp work on each dp_replicate separately. + Hence, we only care about the world size inside the dp_replicate. + """ + + # due to the werid implementation of parallel_dims.py (upstream) + # here we should use `dp_shard_cp` rather then `dp_shard` as of + # CP is also part of the dp_shard + fsdp_group = self.world_mesh["dp_shard_cp"].get_group() + + if self.dp_replicate_enabled: + dp_replicate_group = self.world_mesh["dp_replicate"].get_group() + + if self.tp_enabled: + tp_group = self.world_mesh["tp"].get_group() + tp_world_size = dist.get_world_size(group=tp_group) + + world_size = dist.get_world_size(fsdp_group) + rank = dist.get_rank(fsdp_group) + + # @ THIS IS A HACK + bucket_size = world_size + total_buckets = math.ceil(len(fsdp_params) / bucket_size) + + device = fsdp_params[0].device + cast_dtype = self.communication_dtype + zero_tensor = partial(torch.empty, dtype=cast_dtype, device=device) + + # Process each bucket + for bucket_idx in range(total_buckets): + start_idx = bucket_idx * bucket_size + end_idx = min(start_idx + bucket_size, len(fsdp_params)) + + # Step 1: Prepare data for first all_to_all + grads_send_list, send_shapes = [], [] + target_shape, param_kwargs = None, None + + for rank_idx in range(world_size): + current_rank_idx = start_idx + rank_idx + + if current_rank_idx < len(fsdp_params): + p = fsdp_params[current_rank_idx] + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + + g = self.get_momentum_or_grad( + p, + momentum, + nesterov, + update_buffer=True, + gather_to_local=False, + ) + + original_placements = g.placements + tp_mesh_dim = tp_axis(original_placements) + if tp_group is not None and tp_mesh_dim is not None: + # the reason we need `tp_mesh_dim` is we want a flexible solution + # that Attention go TP and MLP go EP + g = gather_tp_shard( + g.to_local(), tp_group, tp_world_size, original_placements + ).to(dtype=cast_dtype) + else: + g = g.to_local().to(dtype=cast_dtype) + + # Save the shape info for this parameter + if rank == rank_idx: + target_shape = p.shape + else: + # Use a dummy shape for parameters beyond our range + p = fsdp_params[end_idx - 1] + g = zero_tensor(p.to_local().shape) + + grads_send_list.append(g) + send_shapes.append(g.shape) + + # Make sure target_shape is initialized + # (trigger by the padding of the last ranks) + if target_shape is None and end_idx > 0: + target_shape = fsdp_params[end_idx - 1].shape + param_kwargs = self.groups_info[ + self.parameters_to_groups[id(fsdp_params[end_idx - 1])] + ][-1] + + recv_shapes = [ + calculate_shard_shape(target_shape, rank_idx, world_size) + for rank_idx in range(world_size) + ] + recv_list = [zero_tensor(shape) for shape in recv_shapes] + # Step 3: First all_to_all - using ASYNC version + dist.barrier() + dist.all_to_all(recv_list, grads_send_list, group=fsdp_group) + # Step 5: Concatenate received gradients along dimension 0 and perform NS5 + # All tensors in recv_list should have the same dimensions except for dim 0 + + full_g 
= torch.cat(recv_list, dim=0) + u = self.lmo(full_g, **param_kwargs) + dist.barrier(group=fsdp_group) + + if dp_replicate_group is not None and self.extra_reduce_for_HSDP: + dist.all_reduce(u, group=dp_replicate_group, op=dist.ReduceOp.AVG) + dist.barrier(group=dp_replicate_group) + # in case of FSDP+DP, we can do a All-Reduce here sync the grads + if not skip_update: + # Step 6: Split the processed tensor back for second all_to_all + split_sizes = [shape[0] for shape in recv_shapes] + + grads_send_list = list(torch.split(u, split_sizes, dim=0)) + recv_list = [zero_tensor(shape) for shape in send_shapes] + # Step 8: Second all_to_all - using ASYNC version + dist.all_to_all(recv_list, grads_send_list, group=fsdp_group) + del grads_send_list + # Step 10: Update parameters using the results + self.update_bucket_params( + fsdp_params, + recv_list, + start_idx, + end_idx, + tp_group=tp_group, + ) + + @torch.no_grad() + def get_momentum_or_grad( + self, p, momentum, nesterov, update_buffer=False, gather_to_local=False + ): + g = p.grad + if g is None or not p.requires_grad: + return None + + use_momentum = momentum > 0 and momentum < 1 + + if not self.is_light and use_momentum: + state = self.state[p] + if "momentum_buffer" not in state.keys(): + if update_buffer: + state["momentum_buffer"] = torch.zeros_like(g) + else: + """ + When you using DDP + Dist-muon,you might trieer an error here. + Because in the optimizer.log you try to log all gradient's norm. + But for DDP + Dist-muon, each rank only has a part of the gradient. + + -- + For debug, you can return None here. + """ + raise ValueError( + "Momentum buffer not found in optimizer state. " + "Please check if the optimizer is initialized correctly." + ) + buf = state["momentum_buffer"] + if update_buffer: + buf.mul_(1 - momentum).add_(g, alpha=momentum) + else: + buf = buf.mul(1 - momentum).add(g, alpha=momentum) + g = buf if not nesterov else buf.mul(1 - momentum).add(g, alpha=momentum) + + if gather_to_local and isinstance(g, DTensor): + g = g.redistribute(placements=[Replicate()] * g.device_mesh.ndim).to_local() + return g diff --git a/torchtitan/experiments/distributed_scion/distributed_scion.py b/torchtitan/experiments/distributed_scion/distributed_scion.py new file mode 100644 index 000000000..c90a1d528 --- /dev/null +++ b/torchtitan/experiments/distributed_scion/distributed_scion.py @@ -0,0 +1,1254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
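`step_fsdp` above processes parameters in buckets of `world_size`: within one bucket, rank `r` receives everyone's dim-0 shards of parameter `start_idx + r` via `all_to_all`, concatenates them into the full matrix, runs `lmo` on it, and scatters the resulting rows back with a second `all_to_all`. A small sketch of that shard bookkeeping is shown below, using `calculate_shard_shape` adapted from this file; the bucket of parameter shapes and the world size are made-up example values, and no process group is needed to run it.

```python
# Sketch of the bucket layout used by step_fsdp: rank r reassembles the full
# matrix of parameter (start_idx + r) from every rank's dim-0 shard.
import torch


def calculate_shard_shape(shape, rank, world_size):
    # same chunking rule FSDP uses for dim-0 sharding
    splits = torch.arange(shape[0]).chunk(world_size)
    dim0 = len(splits[rank]) if rank < len(splits) else 0
    return (dim0, *shape[1:])


world_size = 4
# one hypothetical bucket of world_size FSDP-sharded 2-D parameters
param_shapes = [(4096, 1024), (1024, 4096), (1024, 1024), (256, 1024)]

for rank in range(world_size):
    owned = param_shapes[rank]  # parameter this rank reassembles and runs lmo() on
    recv = [calculate_shard_shape(owned, r, world_size) for r in range(world_size)]
    assert sum(s[0] for s in recv) == owned[0]  # shards tile the full dim-0
    print(f"rank {rank}: receives shards {recv} -> full {owned}")
```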
+ +import math +from enum import Enum +from functools import partial + +import torch +import torch.distributed as dist +import torch.distributed.tensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard + +from torchtitan.tools.logging import logger + +from .muon_utils import zeropower_backends +from .norm_helper import calculate_norm, NORM_FUNCTIONS +from .utils import remove_orig_mod_and_weight_for_p_name + +__all__ = [ + "DistributedScion", +] + + +class ParamType(Enum): + DDP = 0 + FSDP = 1 + Expert = 2 + Unknown = 3 + + +def get_param_type(p, fsdp_enabled, expert_enabled): + """ + We can aggressively assume that the param is FSDP-Sharded + """ + if p.grad is None: + return ParamType.Unknown + if not fsdp_enabled and not expert_enabled and isinstance(p, torch.Tensor): + return ParamType.DDP + if p.ndim == 3: + return ParamType.Expert + elif fsdp_enabled: + return ParamType.FSDP + else: + return ParamType.Unknown + + +def tp_axis(placements: tuple, tp_enabled: bool = False) -> int | None: + """ + Return the index in `placements` that belongs to *tensor-parallel* (TP). + + Heuristics (PyTorch-TP default layouts): + 1. Row-parallel weights ⇒ `_StridedShard` ⟶ that axis is TP. + 2. Col-parallel weights ⇒ `Shard(dim != 0)` ⟶ that axis is TP + (FSDP shards dim-0, so a non-zero dim means TP). + """ + # rule 1 – row-parallel + for i, p in enumerate(placements): + if isinstance(p, _StridedShard): + return i + + # rule 2 – col-parallel + for i, p in enumerate(placements): + if isinstance(p, Shard) and p.dim != 0: + return i + + # this is a special case, We do TP only + if tp_enabled and len(placements) == 1: + if isinstance(placements[0], Shard): + return 0 + return None # could not infer + + +def gather_tp_shard(tensor, tp_group, tp_world_size, original_placements): + # TP is used, we need to gather the TP-shard params first + tp_mesh_dim = tp_axis(original_placements, True) + assert tp_mesh_dim is not None, "something wrong here" + shard_dim = original_placements[tp_mesh_dim].dim + + output_tensors = [torch.empty_like(tensor) for _ in range(tp_world_size)] + dist.all_gather(output_tensors, tensor, group=tp_group) + return torch.cat(output_tensors, dim=shard_dim) + + +def calculate_shard_shape(shape, rank, world_size): + full = shape[0] + splits = torch.arange(full).chunk(world_size) + if rank >= len(splits): + dim0 = 0 + else: + dim0 = len(splits[rank]) + + return (dim0, *shape[1:]) + + +class DistributedScion(torch.optim.Optimizer): + def __init__( + self, + params, + is_light, + weight_decay, + lr, + momentum, + nesterov, + eps, + norm_factor, + backend, + backend_steps, + parallel_dims, + communication_dtype=torch.bfloat16, + extra_reduce_for_HSDP=False, + experts_weights_layout="G-D_out-D_in", + ): + self.need_to_calculate_norm = False + self.norms_to_log: list[str] = list(NORM_FUNCTIONS.keys()) + self.norms_at_current_step = {} + self.extra_reduce_for_HSDP = False + self.log_parameters_types = True + + defaults = dict( + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov=nesterov, + eps=eps, + norm_factor=norm_factor, + backend=backend, + backend_steps=backend_steps, + ) + self.is_light = is_light + + assert self.is_light is False, " light mode not tested yet" + + is_unconstrained = weight_decay == 0 + + self.world_mesh = parallel_dims.world_mesh + + self.fsdp_enabled = parallel_dims.fsdp_enabled + self.expert_enabled = parallel_dims.ep_enabled + self.dp_replicate_enabled = 
parallel_dims.dp_replicate_enabled + self.tp_enabled = parallel_dims.tp_enabled + + # this is used to ensure only the DP or FSDP rank 0 will have norms + self.is_dp_rank_0 = dist.get_rank(self.world_mesh["dp_cp"].get_group()) == 0 + + assert experts_weights_layout in [ + "G-D_in-D_out", + "G-D_out-D_in", + ], f"Unknown experts weights layout: {experts_weights_layout}" + self.experts_need_transpose = experts_weights_layout == "G-D_in-D_out" + self.extra_reduce_for_HSDP = extra_reduce_for_HSDP + + logger.info( + f"Distributed Scion optimizer " + f"(is_light={self.is_light}, is_unconstrained={is_unconstrained}) " + f"is enabled with world_mesh={self.world_mesh} | fsdp_enabled={self.fsdp_enabled} | " + f"EP={self.expert_enabled} | TP={self.tp_enabled} | DP={self.dp_replicate_enabled}" + ) + + super().__init__(params, defaults) + if self.is_light: + # Initialize state + self._store_grads_in_state() + # Do not pass `self` through syntactic sugar. We need the + # argument to not be populated. + self.register_state_dict_pre_hook( + type(self)._store_grads_in_state, + ) + self.register_load_state_dict_post_hook( + type(self)._load_grads_from_state, + ) + + self.communication_dtype = communication_dtype + self.groups_info = {} + self.parameters_to_groups = {} + for group_idx, group in enumerate(self.param_groups): + lr = group["lr"] + nesterov = group["nesterov"] + momentum = group["momentum"] + wd = group["weight_decay"] + param_kwargs = { + "eps": group["eps"], + "norm_factor": group["norm_factor"], + "zeropower_backend": group["backend"], + "backend_steps": group["backend_steps"], + } + self.groups_info[group_idx] = [lr, nesterov, momentum, wd, param_kwargs] + for param in group["params"]: + self.parameters_to_groups[id(param)] = group_idx + + if self.is_light and nesterov: + raise RuntimeError( + "Nesterov momentum is not supported for Scion's light mode. " + "Please set nesterov=False." 
+ ) + + def calculate_norm_at_next_step(self, norms_to_log: list[str]): + self.need_to_calculate_norm = True + self.norms_to_log = norms_to_log + self.norms_at_current_step = {} + + def get_norms_at_current_step(self): + if self.is_dp_rank_0: + return self.norms_at_current_step + else: + return {} + + def zero_grad(self, *args, **kwargs): + if self.is_light: + pass + else: + super().zero_grad(*args, **kwargs) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + scale_params, scale_param_names = [], [] + embed_params, embed_param_names = [], [] + ddp_params, ddp_param_names = [], [] + fsdp_params, fsdp_param_names = [], [] + expert_params, expert_param_names = [], [] + + for group_idx, group in enumerate(self.param_groups): + # we should update self.groups_info here incase we have LR and momentum scheduler + # We can also optionally do norm_factor and backend scheduler if we want to + lr = group["lr"] + nesterov = group["nesterov"] + momentum = group["momentum"] + wd = group["weight_decay"] + param_kwargs = { + "eps": group["eps"], + "norm_factor": group["norm_factor"], + "zeropower_backend": group["backend"], + "backend_steps": group["backend_steps"], + } + self.groups_info[group_idx] = [lr, nesterov, momentum, wd, param_kwargs] + + for p_name, p in zip(group["param_names"], group["params"]): + norm_factor = group["norm_factor"] + backend = group["backend"] + is_embed_norm = norm_factor.startswith( + "embed" + ) or norm_factor.startswith("unembed") + + if p.numel() == 1: + assert ( + backend == "identity" + ), "scale params must use identity backend" + assert ( + norm_factor == "sign" + ), "scale params must use sign norm factor" + scale_params.append(p) + scale_param_names.append(p_name) + continue + + if backend == "identity" and is_embed_norm: + # for these Row/Col-wise norm, there is no need to gather the gradient + embed_params.append(p) + embed_param_names.append(p_name) + continue + + param_type = get_param_type(p, self.fsdp_enabled, self.expert_enabled) + if param_type == ParamType.DDP: + ddp_params.append(p) + ddp_param_names.append(p_name) + elif param_type == ParamType.FSDP: + fsdp_params.append(p) + fsdp_param_names.append(p_name) + elif param_type == ParamType.Expert: + expert_params.append(p) + expert_param_names.append(p_name) + elif param_type == ParamType.Unknown: + logger.warning( + f"Unknown param type: {p_name}, p.shape {p.shape}, grad is None[?] 
" + f"{p.grad is None}, the optimizer will skip this param" + ) + # raise ValueError(f"Unknown param type: {p_name}") + continue + else: + raise ValueError("param_type") + + # Sort fsdp_params and their names together + fsdp_pairs = list(zip(fsdp_params, fsdp_param_names)) + fsdp_pairs.sort(key=lambda x: x[0].numel(), reverse=True) + fsdp_params, fsdp_param_names = zip(*fsdp_pairs) if fsdp_pairs else ([], []) + # Sort expert_params and their names together + expert_pairs = list(zip(expert_params, expert_param_names)) + expert_pairs.sort(key=lambda x: (x[0].numel(), x[0].shape[1]), reverse=True) + expert_params, expert_param_names = ( + zip(*expert_pairs) if expert_pairs else ([], []) + ) + if self.log_parameters_types: + # only log once + logger.info( + f"fsdp_params: {len(fsdp_params)} | expert_params: {len(expert_params)} | " + f"ddp_params: {len(ddp_params)} | embed_params: {len(embed_params)} | " + f"scale_params: {len(scale_params)}" + ) + self.log_parameters_types = False + + """ + We could merge `embed_params` and `expert_params` into one list. + The diff is, we are sure expert_params have bunch of 2D full-matrixs + But we might need to gather the `embed_params` to 2D full-matrixs + if we wanna to get the norm of the gradient. + """ + self.step_scalar(scale_params, scale_param_names) + self.step_embedding(embed_params, embed_param_names) + self.step_experts(expert_params, expert_param_names) + self.step_ddp(ddp_params, ddp_param_names) + self.step_fsdp(fsdp_params, fsdp_param_names) + + # reset the flag for the next step + self.need_to_calculate_norm = False + return loss + + @torch.no_grad() + def lmo( + self, + g, + eps, + norm_factor, + zeropower_backend, + backend_steps, + ): + g = g.to_local() if isinstance(g, DTensor) else g + + # NB: make sure this function does not modify the grad inplace + # since it is also called during the log of gradients + def _lmo_for_2d_tensor(g, need_transpose=False): + g = g if not need_transpose else g.transpose(0, 1) + g = zeropower_backends[zeropower_backend](g, steps=backend_steps, eps=eps) + g = self.normalise_grad(g, norm_factor=norm_factor, eps=eps) + return g if not need_transpose else g.transpose(0, 1) + + if g.ndim == 2: + g = _lmo_for_2d_tensor(g, need_transpose=False) + elif g.ndim == 3: + if g.shape[0] > 0: + # When world_size [fsdp x EP] > Total number of experts, + # some ranks may have 0 experts that shape will be [0, d-in, d-out] + # We should return the original grad here and **do not** do stack + g = torch.stack( + [ + _lmo_for_2d_tensor( + g[i], need_transpose=self.experts_need_transpose + ) + for i in range(g.shape[0]) + ], + dim=0, + ) + else: + pass + elif g.ndim == 1: + if zeropower_backend != "identity": + g_diag = torch.diag_embed(g).contiguous() + result_diag = _lmo_for_2d_tensor(g_diag) + g = result_diag.diagonal().contiguous() + else: + g = _lmo_for_2d_tensor(g) + + # TODO(JSC): JUST HARD CODE IT TO USE 'identity' backend and 'bias_rms' norm_factor for + # now until we add regex to extra the norm's weights + # zeropower_backend = "identity" + # norm_factor = "bias_rms" + # g = _lmo_for_2d_tensor(g) + + else: + raise ValueError(f"Unknown grad shape: {g.shape}") + + return g + + @torch.no_grad() + def normalise_grad(self, g, norm_factor, eps): + if norm_factor == "spectral": + g = g * (g.size(0) / g.size(1)) ** 0.5 + elif norm_factor == "image_spectral": + g = g * max((g.size(0) / g.size(1)) ** 0.5, 1) + elif norm_factor.startswith("embed"): + # NB: here assume shape [vocab_size, embed_dim] + rms_values = 
torch.sqrt(g.pow(2).sum(axis=1, keepdim=True)) + g = g / (rms_values + eps) + if norm_factor == "embed_linear": + g = g * g.size(1) + elif norm_factor == "embed_sqrt": + g = g * g.size(1) ** 0.5 + else: + raise ValueError(f"Unknown norm_factor: {norm_factor}") + elif norm_factor.startswith("unembed"): + rms_values = torch.sqrt(g.pow(2).sum(axis=1, keepdim=True)) + g = g / (rms_values + eps) + if norm_factor == "unembed_linear": + g = g / g.size(1) + elif norm_factor == "unembed_sqrt": + g = g / g.size(1) ** 0.5 + else: + raise ValueError(f"Unknown norm_factor: {norm_factor}") + elif norm_factor == "sign": + g = torch.sign(g) + elif norm_factor == "bias_rms": + rms_value = torch.sqrt(g.pow(2).mean()) + g = g / (rms_value + eps) + elif norm_factor == "none": + pass + else: + raise ValueError(f"Unknown norm_factor: {norm_factor}") + + return g + + def __getstate__(self): + self._store_grads_in_state() + return super().__getstate__() + + def __setstate__(self, state): + super().__setstate__(state) + self._load_grads_from_state() + + def _store_grads_in_state(self): + for group in self.param_groups: + for param in group["params"]: + if isinstance(param, torch.Tensor) and param.grad is not None: + self.state.setdefault(param, {})["grad_state"] = param.grad + + def _load_grads_from_state(self): + for param, state in self.state.items(): + if "grad_state" in state: + param.grad = state["grad_state"] + elif isinstance(param, torch.Tensor): + param.grad = None + + def update_bucket_params(self, params, updates, start_idx, end_idx, tp_group=None): + # TODO(JSC): we could maybe use tesnor update rather than for-loop here + # can be helpful for FSDP and EP params + for idx_in_bucket in range(start_idx, end_idx): + shift = idx_in_bucket - start_idx + p = params[idx_in_bucket] + u = updates[shift] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + + if wd != 0: + # p.data.mul_(1 - wd*lr) + p.mul_(1 - wd * lr) + + if isinstance(p, DTensor) and self.tp_enabled: + original_placements = p.placements + tp_mesh_dim = tp_axis(original_placements, p.shape == u.shape) + + if isinstance(p, DTensor): + if tp_group is None or tp_mesh_dim is None: + p.to_local().add_(u, alpha=-lr) + else: + tp_rank = tp_group.rank() + tp_sharded_dim = original_placements[tp_mesh_dim].dim + chunk_size = p.to_local().shape[tp_sharded_dim] + start_offset = tp_rank * chunk_size + + slicer = [slice(None)] * u.dim() + slicer[tp_sharded_dim] = slice( + start_offset, start_offset + chunk_size + ) + u_sliced = u[slicer] + p.to_local().add_(u_sliced, alpha=-lr) + else: + p.add_(u, alpha=-lr) + + if momentum != 1 and self.is_light and p.grad is not None: + p.grad.mul_(1 - momentum) + + def step_scalar( + self, + scalar_params, + scalar_param_names, + skip_update=False, + apply_on_weight=True, + ): + if len(scalar_params) == 0: + return {} + + need_to_calculate_norm = self.need_to_calculate_norm + + final_norms = {} + apply_on_weight = apply_on_weight and need_to_calculate_norm + + for param_idx in range(len(scalar_params)): + p = scalar_params[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + g = g.to_local() if isinstance(g, DTensor) else g + + # the lmo of scalar is just sign + u = torch.sign(g) + + if not skip_update: + self.update_bucket_params([p], [u], 0, 1) + + if need_to_calculate_norm: + cleaned_p_name = 
remove_orig_mod_and_weight_for_p_name( + scalar_param_names[param_idx] + ) + p = p.to_local() if isinstance(p, DTensor) else p + # final_norms[f"scalar_update_supremum/{cleaned_p_name}"] = -lr * u + # seems no need to log the update norm ? its always equals to LR + final_norms[f"scalar_param_supremum/{cleaned_p_name}"] = p.abs() + + self.norms_at_current_step.update(final_norms) + + def step_embedding( + self, + embed_params, + embed_param_names, + skip_update=False, + apply_on_weight=True, + ): + if len(embed_params) == 0: + return {} + + need_to_calculate_norm = self.need_to_calculate_norm + + tp_group = None + # if self.dp_replicate_enabled: + # dp_replicate_group = self.world_mesh["dp_replicate"].get_group() + # else: + # dp_replicate_group = None + + if self.tp_enabled: + tp_group = self.world_mesh["tp"].get_group() + + norms_of_update, norms_of_weight, final_norms = [], [], {} + apply_on_weight = apply_on_weight and need_to_calculate_norm + + for param_idx in range(len(embed_params)): + p = embed_params[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + + u = self.lmo(g, **param_kwargs) + + ######################################################### + # # As of we use norm for Embedding, maybe we should not do Reduce here + # if ( + # dp_replicate_group is not None + # and self.extra_reduce_for_HSDP + # and self.fsdp_enabled + # ): + # dist.all_reduce(u, group=dp_replicate_group, op=dist.ReduceOp.AVG) + # dist.barrier(group=dp_replicate_group) + if not skip_update: + self.update_bucket_params([p], [u], 0, 1, tp_group=tp_group) + + if not need_to_calculate_norm: + return {} + + # for the embedding, if we want to calculate the norm, we need to gather the gradient + for param_idx in range(len(embed_params)): + p, p_name = embed_params[param_idx], embed_param_names[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + # this is important, *Do NOT* update buffer twice here + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=False, gather_to_local=True + ) + """ + TODO(JSC): maybe we can improve this [?] + Rather than Gather - LMO, we can do LMO - Gather such that we can avoid compute + lmo twice [?]. though lmo of embedding is not expensive + [kind of fixed], we dont do gather here anymore + """ + u = self.lmo(g, **param_kwargs) + + if apply_on_weight and isinstance(p, DTensor): + p = p.full_tensor() + + norm_need_transpose = "tok_embeddings" in p_name + norms_of_update = calculate_norm( + -lr * u, self.norms_to_log, transpose=norm_need_transpose + ) + if apply_on_weight: + norms_of_weight: dict = calculate_norm( + p, self.norms_to_log, transpose=norm_need_transpose + ) + else: + norms_of_weight = None + + # This should _not_ be an f-string since the variable names + # will be interpolated later. 
+ embed_norm_key_template = "track_{task_name}_{norm_name}/{cleaned_p_name}" + cleaned_p_name = remove_orig_mod_and_weight_for_p_name(p_name) + for norm_name in self.norms_to_log: + final_norms[ + embed_norm_key_template.format( + task_name="update", + norm_name=norm_name, + cleaned_p_name=cleaned_p_name, + ) + ] = norms_of_update[norm_name] + if apply_on_weight: + final_norms[ + embed_norm_key_template.format( + task_name="param", + norm_name=norm_name, + cleaned_p_name=cleaned_p_name, + ) + ] = norms_of_weight[norm_name] + self.norms_at_current_step.update(final_norms) + + def step_experts( + self, + expert_params, + expert_param_names, + skip_update=False, + apply_on_weight=True, + ): + if len(expert_params) == 0: + return {} + + need_to_calculate_norm = self.need_to_calculate_norm + + norms_of_update, norms_of_weight, final_norms = [], [], {} + apply_on_weight = apply_on_weight and need_to_calculate_norm + + device = expert_params[0].device + fsdp_group = self.world_mesh["dp_shard_cp"].get_group() + world_size = dist.get_world_size(fsdp_group) + local_rank = dist.get_rank(fsdp_group) + ep_per_rank = math.ceil(expert_params[0].shape[0] / world_size) + + kinds_of_norms = len(self.norms_to_log) + + padding_norms = torch.tensor(0.0, device=device) + # each rank will process `len(expert_params) * ep_per_rank` experts + # each expert will have `self.norms_to_log` norms + # so each rank will have `len(expert_params) * ep_per_rank * len(self.norms_to_log)` + # norms + # globally, its [[g0-ep0, g0-ep1, g0-ep2, ...], [g1-ep0, g1-ep1, g1-ep2, ...], ...] on each + # rank + + transpose = self.experts_need_transpose + for param_idx in range(len(expert_params)): + p = expert_params[param_idx] + lr, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + u = self.lmo(g, **param_kwargs) + + if not skip_update: + self.update_bucket_params([p], [u], 0, 1) + + if need_to_calculate_norm: + # cleaned_p_name = remove_orig_mod_and_weight_for_p_name( + # expert_param_names[param_idx] + # ) + assert u.ndim == 3 + for ep_idx in range(u.shape[0]): + update_norms = calculate_norm( + u[ep_idx], self.norms_to_log, transpose=transpose + ) + # Template for MoE norm keys + norms_of_update.extend(update_norms.values()) + if apply_on_weight: + weight_norms = calculate_norm( + p.to_local()[ep_idx], self.norms_to_log, transpose=transpose + ) + norms_of_weight.extend(weight_norms.values()) + + if need_to_calculate_norm: + expected_total = len(expert_params) * ep_per_rank * kinds_of_norms + pad_needed = expected_total - len(norms_of_update) + if pad_needed > 0: + norms_of_update.extend([padding_norms] * pad_needed) + if apply_on_weight: # keep weight-norms aligned + norms_of_weight.extend([padding_norms] * pad_needed) + + norms_tensor = torch.stack(norms_of_update).float().to(device) + gathered_update_norms = torch.empty( + world_size * norms_tensor.shape[0], + dtype=norms_tensor.dtype, + device=norms_tensor.device, + ) + dist.all_gather_into_tensor( + gathered_update_norms, norms_tensor, group=fsdp_group + ) + + if apply_on_weight: + norms_tensor = torch.stack(norms_of_weight).float().to(device) + gathered_weight_norms = torch.empty( + world_size * norms_tensor.shape[0], + dtype=norms_tensor.dtype, + device=norms_tensor.device, + ) + dist.barrier() + dist.all_gather_into_tensor( + gathered_weight_norms, norms_tensor, group=fsdp_group + ) + + if local_rank == 0: + norm_names = 
list(self.norms_to_log) + + P = len(expert_params) # parameters per rank + E = ep_per_rank # experts per rank + K = kinds_of_norms # norms per expert + block = P * E * K # values contributed by each rank + + for idx in range(world_size * block): + r, rem = divmod(idx, block) # producing rank + p, rem = divmod(rem, E * K) # parameter index + e, k = divmod(rem, K) # expert, norm indices + + actual_ep_idx = e + r * E + if actual_ep_idx >= expert_params[0].shape[0]: + continue # skip pure padding slots + + cleaned_name = remove_orig_mod_and_weight_for_p_name( + expert_param_names[p] + ) + norm_name = norm_names[k] + + key_update = ( + f"track_update_{norm_name}/ep_{actual_ep_idx}/{cleaned_name}" + ) + final_norms[key_update] = gathered_update_norms[idx] + + if apply_on_weight: + key_param = ( + f"track_param_{norm_name}/ep_{actual_ep_idx}/{cleaned_name}" + ) + final_norms[key_param] = gathered_weight_norms[idx] + + self.norms_at_current_step.update(final_norms) + + def step_ddp( + self, + ddp_params, + ddp_param_names, + skip_update=False, + apply_on_weight=True, + ): + # Either we do DDP + # or we do TP, there is no [DDP + TP] case but for safety we add sevel checks + # if len(ddp_params) == 0: + # return {} + + need_to_calculate_norm = self.need_to_calculate_norm + + tp_group, dp_replicate_group = None, None + + rank = 0 + bucket_size = world_size = 1 + total_buckets = len(ddp_params) + + if self.dp_replicate_enabled: + dp_replicate_group = self.world_mesh["dp_replicate"].get_group() + world_size = dp_replicate_group.size() + rank = dp_replicate_group.rank() + + bucket_size = world_size + total_buckets = math.ceil(len(ddp_params) / bucket_size) + + if self.tp_enabled: + tp_group = self.world_mesh["tp"].get_group() + tp_world_size = dist.get_world_size(group=tp_group) + + device = ddp_params[0].device if len(ddp_params) > 0 else torch.device("cuda") + cast_dtype = self.communication_dtype + zero_tensor = partial(torch.zeros, dtype=cast_dtype, device=device) + + norms_of_update, norms_of_weight, final_norms = [], [], {} + padding_norms = { + norm_name: torch.tensor(0.0, device=device) + for norm_name in self.norms_to_log + } + apply_on_weight = apply_on_weight and need_to_calculate_norm + + # for DDP, we need to first update the buffer + for param_idx in range(len(ddp_params)): + p = ddp_params[param_idx] + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=True, gather_to_local=False + ) + + # then we do scion stuff + for bucket_idx in range(total_buckets): + start_idx = bucket_idx * bucket_size + end_idx = min(start_idx + bucket_size, len(ddp_params)) + current_rank_idx = start_idx + rank + if current_rank_idx < len(ddp_params): + p = ddp_params[current_rank_idx] + # Step 1: Get the gradient + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + g = self.get_momentum_or_grad( + p, momentum, nesterov, update_buffer=False, gather_to_local=False + ) + if isinstance(g, DTensor) and self.tp_enabled: + g = gather_tp_shard( + g.to_local(), tp_group, tp_world_size, g.placements + ).to(dtype=cast_dtype) + + else: + # To avoid idle stream, we pad the last rank + p = ddp_params[end_idx - 1] + g = zero_tensor(p.shape) + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + + # step 2: lmo + u = self.lmo(g, **param_kwargs) + + if not skip_update: + # Step 3: FOR DDP, we do all-gather + if 
self.dp_replicate_enabled: + # only gather params when we doing DDP + BUCKET + gather_lists = [None] * world_size + for i in range(world_size): + param_idx = start_idx + i + if i == rank or param_idx >= len(ddp_params): + gather_lists[i] = u.to(dtype=cast_dtype) + elif param_idx < len(ddp_params): + p = ddp_params[start_idx + i] + gather_lists[i] = zero_tensor(p.shape) + dist.all_gather( + gather_lists, u.to(dtype=cast_dtype), group=dp_replicate_group + ) + if self.tp_enabled: + # only if DP+TP we need to barrier here other-wise its automatically synced + dist.barrier(group=dp_replicate_group) + else: + # other wise (TP only), dp world_size is 1 + gather_lists = [u.to(dtype=cast_dtype)] + + # Step 4: Update the parameters + self.update_bucket_params( + ddp_params, gather_lists, start_idx, end_idx, tp_group=tp_group + ) + + if need_to_calculate_norm: + # so here, we already have update of each rank + p = ddp_params[min(current_rank_idx, len(ddp_params) - 1)] + lr, *_ = self.groups_info[self.parameters_to_groups[id(p)]] + + if current_rank_idx < end_idx: + norms_of_update.extend( + calculate_norm(-lr * u, self.norms_to_log).values() + ) + else: + norms_of_update.extend(padding_norms.values()) + if apply_on_weight: + if current_rank_idx < end_idx: + if isinstance(p, DTensor) and self.tp_enabled: + p = gather_tp_shard( + p.to_local(), tp_group, tp_world_size, p.placements + ).to(dtype=cast_dtype) + + norms_of_weight.extend( + calculate_norm(p, self.norms_to_log).values() + ) + else: + norms_of_weight.extend(padding_norms.values()) + + if need_to_calculate_norm and len(norms_of_update) > 0: + + norms_tensor = torch.stack(norms_of_update).to(device=device).float() + if self.dp_replicate_enabled: + gathered_update_norms = torch.empty( + world_size * norms_tensor.shape[0], + dtype=norms_tensor.dtype, + device=norms_tensor.device, + ) + dist.barrier(group=dp_replicate_group) + dist.all_gather_into_tensor(gathered_update_norms, norms_tensor) + else: + gathered_update_norms = norms_tensor + + if apply_on_weight: + norms_tensor = torch.stack(norms_of_weight).to(device=device).float() + if self.dp_replicate_enabled: + gathered_weight_norms = torch.empty( + world_size * norms_tensor.shape[0], + dtype=norms_tensor.dtype, + device=norms_tensor.device, + ) + dist.barrier(group=dp_replicate_group) + dist.all_gather_into_tensor(gathered_weight_norms, norms_tensor) + else: + gathered_weight_norms = norms_tensor + + if rank == 0: + # This should _not_ be an f-string since the variable + # names will be interpolated later. 
+ ddp_norm_key_template = "track_{task_name}_{norm_name}/{cleaned_p_name}" + num_norm_types = len(self.norms_to_log) + # total_buckets is already defined + + for param_idx, p_name in enumerate(ddp_param_names): + cleaned_p_name = remove_orig_mod_and_weight_for_p_name(p_name) + + # Determine which rank and bucket handled this parameter + param_rank = param_idx % world_size + param_bucket = param_idx // world_size + + for norm_idx, norm_name in enumerate(self.norms_to_log): + # Calculate the correct index in the gathered tensor + base_idx = ( + param_rank * total_buckets + param_bucket + ) * num_norm_types + norm_value_idx = base_idx + norm_idx + + final_norms[ + ddp_norm_key_template.format( + task_name="update", + norm_name=norm_name, + cleaned_p_name=cleaned_p_name, + ) + ] = gathered_update_norms[norm_value_idx] + + if apply_on_weight: + final_norms[ + ddp_norm_key_template.format( + task_name="param", + norm_name=norm_name, + cleaned_p_name=cleaned_p_name, + ) + ] = gathered_weight_norms[norm_value_idx] + dist.barrier() + + self.norms_at_current_step.update(final_norms) + + def step_fsdp( + self, + fsdp_params, + fsdp_param_names, + skip_update=False, + apply_on_weight=True, + ): + if len(fsdp_params) == 0: + return {} + need_to_calculate_norm = self.need_to_calculate_norm + tp_group, dp_replicate_group = None, None + """ + To make FSDP+DP works, we lets step_fsdp work on each dp_replicate separately. + Hence, we only care about the world size inside the dp_replicate. + """ + + # due to the werid implementation of parallel_dims.py (upstream) + # here we should use `dp_shard_cp` rather then `dp_shard` as of + # CP is also part of the dp_shard + fsdp_group = self.world_mesh["dp_shard_cp"].get_group() + + if self.dp_replicate_enabled: + dp_replicate_group = self.world_mesh["dp_replicate"].get_group() + + if self.tp_enabled: + tp_group = self.world_mesh["tp"].get_group() + tp_world_size = dist.get_world_size(group=tp_group) + + world_size = dist.get_world_size(fsdp_group) + rank = dist.get_rank(fsdp_group) + + # @ THIS IS A HACK + bucket_size = world_size + total_buckets = math.ceil(len(fsdp_params) / bucket_size) + + device = fsdp_params[0].device + cast_dtype = self.communication_dtype + zero_tensor = partial(torch.empty, dtype=cast_dtype, device=device) + + norms_of_update, norms_of_weight, final_norms = [], [], {} + + padding_norms = { + norm_name: torch.tensor(0.0, device=device) + for norm_name in self.norms_to_log + } + + apply_on_weight = apply_on_weight and need_to_calculate_norm + + # Process each bucket + for bucket_idx in range(total_buckets): + start_idx = bucket_idx * bucket_size + end_idx = min(start_idx + bucket_size, len(fsdp_params)) + + # Step 1: Prepare data for first all_to_all + grads_send_list, send_shapes = [], [] + target_shape, param_kwargs = None, None + + for rank_idx in range(world_size): + current_rank_idx = start_idx + rank_idx + + if current_rank_idx < len(fsdp_params): + p = fsdp_params[current_rank_idx] + _, nesterov, momentum, wd, param_kwargs = self.groups_info[ + self.parameters_to_groups[id(p)] + ] + + g = self.get_momentum_or_grad( + p, + momentum, + nesterov, + update_buffer=True, + gather_to_local=False, + ) + + original_placements = g.placements + tp_mesh_dim = tp_axis(original_placements) + if tp_group is not None and tp_mesh_dim is not None: + # the reason we need `tp_mesh_dim` is we want a flexible solution + # that Attention go TP and MLP go EP + g = gather_tp_shard( + g.to_local(), tp_group, tp_world_size, original_placements + 
).to(dtype=cast_dtype) + else: + g = g.to_local().to(dtype=cast_dtype) + + # Save the shape info for this parameter + if rank == rank_idx: + target_shape = p.shape + else: + # Use a dummy shape for parameters beyond our range + p = fsdp_params[end_idx - 1] + g = zero_tensor(p.to_local().shape) + + grads_send_list.append(g) + send_shapes.append(g.shape) + + # Make sure target_shape is initialized + # (trigger by the padding of the last ranks) + if target_shape is None and end_idx > 0: + target_shape = fsdp_params[end_idx - 1].shape + param_kwargs = self.groups_info[ + self.parameters_to_groups[id(fsdp_params[end_idx - 1])] + ][-1] + + recv_shapes = [ + calculate_shard_shape(target_shape, rank_idx, world_size) + for rank_idx in range(world_size) + ] + recv_list = [zero_tensor(shape) for shape in recv_shapes] + # Step 3: First all_to_all - using ASYNC version + dist.barrier() + dist.all_to_all(recv_list, grads_send_list, group=fsdp_group) + # Step 5: Concatenate received gradients along dimension 0 and perform NS5 + # All tensors in recv_list should have the same dimensions except for dim 0 + + full_g = torch.cat(recv_list, dim=0) + u = self.lmo(full_g, **param_kwargs) + dist.barrier(group=fsdp_group) + + if dp_replicate_group is not None and self.extra_reduce_for_HSDP: + dist.all_reduce(u, group=dp_replicate_group, op=dist.ReduceOp.AVG) + dist.barrier(group=dp_replicate_group) + # in case of FSDP+DP, we can do a All-Reduce here sync the grads + if not skip_update: + # Step 6: Split the processed tensor back for second all_to_all + split_sizes = [shape[0] for shape in recv_shapes] + + grads_send_list = list(torch.split(u, split_sizes, dim=0)) + recv_list = [zero_tensor(shape) for shape in send_shapes] + # Step 8: Second all_to_all - using ASYNC version + dist.all_to_all(recv_list, grads_send_list, group=fsdp_group) + del grads_send_list + # Step 10: Update parameters using the results + self.update_bucket_params( + fsdp_params, + recv_list, + start_idx, + end_idx, + tp_group=tp_group, + ) + + if need_to_calculate_norm: + if start_idx + rank < end_idx: + lr, *_ = self.groups_info[ + self.parameters_to_groups[id(fsdp_params[start_idx + rank])] + ] + norms = calculate_norm(-lr * u, self.norms_to_log) + else: + norms = padding_norms + norms_of_update.extend(norms.values()) + + if apply_on_weight: + params_send_list = [] + for rank_idx in range(world_size): + current_rank_idx = start_idx + rank_idx + if current_rank_idx < len(fsdp_params): + p = fsdp_params[current_rank_idx] + else: + p = fsdp_params[end_idx - 1] + + # her is patch for FSDP+TP + original_placements = p.placements + tp_mesh_dim = tp_axis(original_placements) + if tp_group is not None and tp_mesh_dim is not None: + p = gather_tp_shard( + p.to_local(), + tp_group, + tp_world_size, + original_placements, + ).to(dtype=cast_dtype) + else: + p = p.to_local().to(dtype=cast_dtype) + + params_send_list.append(p) + + recv_list = [zero_tensor(shape) for shape in recv_shapes] + + dist.barrier(group=fsdp_group) + dist.all_to_all(recv_list, params_send_list, group=fsdp_group) + + full_weight = torch.cat(recv_list, dim=0) + + if start_idx + rank < end_idx: + norms = calculate_norm(full_weight, self.norms_to_log) + else: + norms = padding_norms + norms_of_weight.extend(norms.values()) + + # Below we need to all-gather the norms of update to rank-0 + if need_to_calculate_norm and len(norms_of_update) > 0: + # Convert norms_of_update to a flat tensor for all-gather + # Each rank has bucket_size * len(self.norms_to_log) norm values + norms_tensor = 
torch.stack(norms_of_update).to(device=device).float() + gathered_update_norms = torch.empty( + world_size * norms_tensor.shape[0], + dtype=norms_tensor.dtype, + device=norms_tensor.device, + ) + dist.barrier(group=fsdp_group) + dist.all_gather_into_tensor( + gathered_update_norms, norms_tensor, group=fsdp_group + ) + + if apply_on_weight: + norms_tensor = torch.stack(norms_of_weight).to(device=device).float() + gathered_weight_norms = torch.empty( + world_size * norms_tensor.shape[0], + dtype=norms_tensor.dtype, + device=norms_tensor.device, + ) + dist.barrier(group=fsdp_group) + dist.all_gather_into_tensor( + gathered_weight_norms, norms_tensor, group=fsdp_group + ) + + if rank == 0: + # ----- only rank 0 reconstructs ----- + num_norm_types = len(self.norms_to_log) + entries_per_rank = total_buckets * num_norm_types + + # pre-clean names once + cleaned_names = [ + remove_orig_mod_and_weight_for_p_name(pn) for pn in fsdp_param_names + ] + + fsdp_norm_key_template = ( + "track_{task_name}_{norm_name}/{cleaned_p_name}" + ) + + assert ( + gathered_update_norms.numel() == world_size * entries_per_rank + ), "update norms size mismatch" + if apply_on_weight: + assert ( + gathered_weight_norms.numel() == world_size * entries_per_rank + ), "weight norms size mismatch" + + for param_idx, cleaned_p_name in enumerate(cleaned_names): + r = param_idx % world_size # rank that owned this param + b = param_idx // world_size # bucket index on that rank + base = r * entries_per_rank + b * num_norm_types + + for norm_idx, norm_name in enumerate(self.norms_to_log): + idx = base + norm_idx + # Index must exist thanks to padding + final_norms[ + fsdp_norm_key_template.format( + task_name="update", + norm_name=norm_name, + cleaned_p_name=cleaned_p_name, + ) + ] = gathered_update_norms[idx] + + if apply_on_weight: + final_norms[ + fsdp_norm_key_template.format( + task_name="param", + norm_name=norm_name, + cleaned_p_name=cleaned_p_name, + ) + ] = gathered_weight_norms[idx] + + dist.barrier(group=fsdp_group) + self.norms_at_current_step.update(final_norms) + + @torch.no_grad() + def get_momentum_or_grad( + self, p, momentum, nesterov, update_buffer=False, gather_to_local=False + ): + g = p.grad + if g is None or not p.requires_grad: + return None + + use_momentum = momentum > 0 and momentum < 1 + + if not self.is_light and use_momentum: + state = self.state[p] + if "momentum_buffer" not in state.keys(): + if update_buffer: + state["momentum_buffer"] = torch.zeros_like(g) + else: + """ + When you using DDP + Dist-muon,you might trieer an error here. + Because in the optimizer.log you try to log all gradient's norm. + But for DDP + Dist-muon, each rank only has a part of the gradient. + + -- + For debug, you can return None here. + """ + raise ValueError( + "Momentum buffer not found in optimizer state. " + "Please check if the optimizer is initialized correctly." 
+ ) + buf = state["momentum_buffer"] + if update_buffer: + buf.mul_(1 - momentum).add_(g, alpha=momentum) + else: + buf = buf.mul(1 - momentum).add(g, alpha=momentum) + g = buf if not nesterov else buf.mul(1 - momentum).add(g, alpha=momentum) + + if gather_to_local and isinstance(g, DTensor): + g = g.redistribute(placements=[Replicate()] * g.device_mesh.ndim).to_local() + return g diff --git a/torchtitan/experiments/distributed_scion/muon_utils.py b/torchtitan/experiments/distributed_scion/muon_utils.py new file mode 100644 index 000000000..27ed383ed --- /dev/null +++ b/torchtitan/experiments/distributed_scion/muon_utils.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import repeat + +import torch + + +def zeropower_via_svd(G, **kwargs): + original_dtype = G.dtype + G = G.to(torch.float32) + # SVD does not support bfloat16 + if G.size(0) > G.size(1): + G = G.T + transpose = True + else: + transpose = False + U, S, V = G.svd() + X = U @ V.T + if transpose: + X = X.T + return X.to(original_dtype).contiguous() + + +# Polar Express +@torch.compile +def zeropower_via_polar_express(G, steps=5, eps=1e-7): + # https://arxiv.org/abs/2505.16932 + coeffs_base = [ + (8.28721201814563, -23.595886519098837, 17.300387312530933), + (4.107059111542203, -2.9478499167379106, 0.5448431082926601), + (3.948690853482295, -2.908902115962949, 0.5518191394370137), + (3.318419657370602, -2.488488024314874, 0.51004894012372), + (2.300652019954817, -1.668903984574749, 0.4188073119525673), + (1.891301407787398, -1.267995827194587, 0.3768040894852483), + (1.875001480853448, -1.250001645399949, 0.3750001645474248), + (1.875000000000000, -1.250000000000000, 0.375000000000000), # limit + ] + + # apply the 1/1.01 stabiliser **only** to the first seven triples + coeffs_base = [ + (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in coeffs_base[:-1] + ] + [coeffs_base[-1]] + + # extend the list so that coeffs[k] is defined for every k < steps + coeffs = coeffs_base + list( + repeat(coeffs_base[-1], max(0, steps - len(coeffs_base))) + ) + + original_dtype = G.dtype + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + X = X / (torch.linalg.norm(X) + eps) # ensure top singular value <= 1 + + # main loop + for k in range(steps): + a, b, c = coeffs[k] + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + + return X.to(original_dtype) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' \\sim Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + + assert ( + len(G.shape) == 2 + ), f"Please make sure gradients are 2D tensors to use NS, got shape: {G.shape}" + a, b, c = (3.4445, -4.7750, 2.0315) + # for a, b, c in [ # updated coefficients from @leloykun + # (4.0848, -6.8946, 2.9270), + # (3.9505, -6.3029, 2.6377), + # (3.7418, -5.5913, 2.3037), + # (2.8769, -3.1427, 1.2046), + # (2.8366, -3.0525, 1.2012), + # ]: + original_dtype = G.dtype + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + X = X / (torch.linalg.norm(X) + eps) # ensure top singular value <= 1 + + for _ in range(steps): + A = X @ X.T + B = ( + b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + + return X.to(original_dtype) + + +zeropower_backends = dict( + svd=zeropower_via_svd, + newtonschulz5=zeropower_via_newtonschulz5, + polar_express=zeropower_via_polar_express, + identity=lambda x, **kwargs: x, +) diff --git a/torchtitan/experiments/distributed_scion/naive_param_norm.py b/torchtitan/experiments/distributed_scion/naive_param_norm.py new file mode 100644 index 000000000..d2c47ae14 --- /dev/null +++ b/torchtitan/experiments/distributed_scion/naive_param_norm.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import Replicate + +from torchtitan.experiments.distributed_scion import DistributedScion +from torchtitan.experiments.distributed_scion.norm_helper import calculate_norm + +from .utils import remove_orig_mod_and_weight_for_p_name + + +""" +This is the naive version of parameter norm calculation. +It is used to compute norm for other optimizers. +In distributed scion, we will automatically calcuate the norms in distributed mode. +""" + + +def gather_and_merge(local_stats: dict, dst: int = 0): + # this is the old-implementation of gather_and_merge, which is used + # to gather the norms of the parameters + # it only tested on the "FSDP-only" case. 
+ world = dist.get_world_size() + rank = dist.get_rank() + dtype = torch.bfloat16 + + my_keys = list(local_stats.keys()) + + if len(my_keys) > 0: + val_tensor = torch.stack([local_stats[k].to(dtype) for k in my_keys]) + else: + my_keys = "padding" + val_tensor = None + + key_bucket = [None] * world if rank == dst else None + val_bucket = [None] * world if rank == dst else None + + dist.gather_object(my_keys, key_bucket, dst=dst) + # dist.barrier() + dist.gather_object(val_tensor, val_bucket, dst=dst) + dist.barrier() + + merged = {} + if rank == dst: + for peer, keys in enumerate(key_bucket): + if val_bucket[peer] is None: + continue + for k, v in zip(keys, val_bucket[peer]): + if k != "padding": + merged[k] = v + + dist.barrier() + if rank == dst: + return merged + else: + return {} + + +def compute_grad(p, optimizer=None, **kwargs): + if isinstance(optimizer, (Scion, DistributedScion)): + momentum = kwargs.pop("momentum") + nesterov = kwargs.pop("nesterov") + g = optimizer.get_momentum_or_grad( + p, + momentum, + nesterov, + update_buffer=False, + gather_to_local=optimizer.fsdp_enabled and p.ndim < 3, + # we do not gather the moe's grads + ) + if g is None: + return None + else: + g = g.to_local() if isinstance(g, DTensor) else g + return optimizer.lmo(g, **kwargs) + elif isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)): + if p.ndim == 3: + warnings.warn( + f"Optimizer {optimizer.__class__.__name__} does not support " + f"gradient computation for 3D tensors for logging." + ) + return None + + eps = kwargs["eps"] + weight_decay = kwargs["weight_decay"] + beta1, beta2 = kwargs["betas"] + assert weight_decay == 0.0, "Weight decay not supported for grad computation." + + param_optim_state = optimizer.state[p] + if "step" not in param_optim_state: + step = 0 + else: + step = param_optim_state["step"].item() + if "exp_avg_sq" in param_optim_state and "exp_avg" in param_optim_state: + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + denom = ( + param_optim_state["exp_avg_sq"].sqrt() / math.sqrt(bias_correction2) + ) + eps + step_size = 1 / bias_correction1 + g = step_size * param_optim_state["exp_avg"].div(denom) + else: + # TODO(JSC): if we shard the MoE model, we need to remove the following code + g = p.grad + + if isinstance(g, DTensor): + g = g.redistribute(placements=[Replicate()] * g.device_mesh.ndim) + return g + else: + raise TypeError( + f"Optimizer {optimizer.__class__.__name__} does not support " + f"gradient computation." + ) + + +def get_parameter_norms(model_parts, optimizers, norms_to_log): + all_norms = {} + for i, _ in enumerate(model_parts): + # NB: assumes correspondences between model parts and optimizers + optimizer = optimizers[i] + for group in optimizer.param_groups: + if isinstance(optimizer, (Scion, DistributedScion)): + param_kwargs = { + "momentum": group["momentum"], + "nesterov": group["nesterov"], + "eps": group["eps"], + "norm_factor": group["norm_factor"], + "zeropower_backend": group["backend"], + "backend_steps": group["backend_steps"], + } + elif isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)): + param_kwargs = { + "eps": group["eps"], + "betas": group["betas"], + "weight_decay": group["weight_decay"], + } + else: + warnings.warn( + f"Optimizer {optimizer.__class__.__name__} does not support " + f"norm computation." 
+ ) + continue + + FLAG_NEED_SYNC = False + moe_norms, fsdp_norms = {}, {} + for p_name, p in zip(group["param_names"], group["params"]): + # The module is usually named + # `track_update_condition_number/model_part_0/layers.0._orig_mod.attention.wo.weight` + cleaned_p_name = remove_orig_mod_and_weight_for_p_name(p_name) + g = compute_grad(p, optimizer, **param_kwargs) + if g is None: + continue + assert not torch.isnan(g).any(), f"There is nan in the grad of {p_name}" + + if p.ndim < 3: + p = ( + p.redistribute(placements=[Replicate()] * p.device_mesh.ndim) + if isinstance(p, DTensor) + else p + ) + else: + FLAG_NEED_SYNC = True + local_rank = dist.get_rank() + world_size = dist.get_world_size() + ep_per_rank = math.ceil(p.shape[0] / world_size) + # We dont gather the parameters for 3D tensors, + # which is [G, D_in, D_out] of GroupedExperts + pass + p = p.to_local() if isinstance(p, DTensor) else p + g = g.to_local() if isinstance(g, DTensor) else g + update = -group["lr"] * g + + # #################################################### + for task, matrix in [("update", update), ("param", p)]: + if matrix.ndim == 3: + moe_norm_key_template = f"track_{task}_{{norm_name}}/ep_{{actual_ep_idx}}/{cleaned_p_name}" + for ep_idx in range(matrix.shape[0]): + actual_ep_idx = ep_idx + local_rank * ep_per_rank + update_norms = calculate_norm( + matrix[ep_idx], norms_to_log, transpose=True + ) + # Template for MoE norm keys + moe_norms.update( + { + moe_norm_key_template.format( + norm_name=norm_name, + actual_ep_idx=actual_ep_idx, + ): norm_value + for norm_name, norm_value in update_norms.items() + } + ) + else: + if matrix.ndim > 2: + warnings.warn( + f"Encountered parameter or update {cleaned_p_name} with " + f"shape {p.shape} or {update.shape}, respectively; " + f"this may not be an issue, but please ensure its " + f"norms are calculated correctly." + ) + + transpose = "tok_embeddings" in p_name + update_norms = calculate_norm( + matrix, + norms_to_log, + transpose=transpose, + ) + + # Template for FSDP norm keys + fsdp_norm_key_template = ( + f"track_{task}_{{norm_name}}/{cleaned_p_name}" + ) + fsdp_norms.update( + { + fsdp_norm_key_template.format( + norm_name=norm_name + ): norm_value + for norm_name, norm_value in update_norms.items() + } + ) + + if FLAG_NEED_SYNC: + # remove the comment below to gather the moe_norms on all ranks + moe_norms = gather_and_merge(moe_norms) + pass + + all_norms.update(fsdp_norms) + all_norms.update(moe_norms) + + return all_norms diff --git a/torchtitan/experiments/distributed_scion/norm_helper.py b/torchtitan/experiments/distributed_scion/norm_helper.py new file mode 100644 index 000000000..fe66b391f --- /dev/null +++ b/torchtitan/experiments/distributed_scion/norm_helper.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from torch.distributed.tensor import DTensor + + +@torch.no_grad() +def rms_to_rms_norm(W): + """ + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. 
+ """ + assert W.ndim == 2, "operator norm can only be applied to matrices" + norm = torch.linalg.norm(W.to(torch.float32), ord=2, dtype=torch.float32) + fan_out, fan_in = W.shape + scale = math.sqrt(fan_in / fan_out) + norm *= scale + return norm + + +@torch.no_grad() +def l1_to_rms_norm(W): + assert W.ndim == 2, "operator norm can only be applied to matrices" + norm = torch.max( + torch.linalg.norm(W.to(torch.float32), ord=2, dim=0, dtype=torch.float32) + ) + scale = torch.sqrt(torch.tensor(W.shape[0], dtype=W.dtype, device=W.device)) + norm /= scale + return norm + + +@torch.no_grad() +def rms_to_l1_norm(W): + assert W.ndim == 2, "operator norm can only be applied to matrices" + norm = torch.max( + torch.linalg.norm(W.to(torch.float32), ord=2, dim=1, dtype=torch.float32) + ) + scale = torch.sqrt(torch.tensor(W.shape[1], dtype=W.dtype, device=W.device)) + norm *= scale + return norm + + +@torch.no_grad() +def supremum_norm(x): + return x.abs().max() + + +@torch.no_grad() +def condition_number(W): + assert W.ndim == 2, "condition number calculation can only be applied to matrices" + S = torch.linalg.svdvals(W.to(torch.float32), driver="gesvd") + return S[0] / S[-1] + + +@torch.no_grad() +def frobenius_norm(W): + return torch.linalg.norm(W.float(), ord="fro") + + +@torch.no_grad() +def average_entry_size(W): + # https://docs.modula.systems/examples/weight-erasure/ + return frobenius_norm(W) / math.sqrt(W.numel()) + + +@torch.no_grad() +def stable_rank(W): + # https://docs.modula.systems/examples/weight-erasure/ + S = torch.linalg.svdvals(W.to(torch.float32), driver="gesvd") + spec = S[0] + if spec == 0: + return torch.tensor(0.0, device=W.device) + frob_norm = frobenius_norm(W) + return (frob_norm**2) / (spec**2) + + +@torch.no_grad() +def effective_rank(W): + # https://docs.modula.systems/examples/weight-erasure/ + S = torch.linalg.svdvals(W.to(torch.float32), driver="gesvd") + p = (S / (S.sum() + 1e-12)).clamp_min(1e-12) + return torch.exp(-(p * p.log()).sum()) + + +NORM_FUNCTIONS = { + "rms_to_rms": rms_to_rms_norm, + "l1_to_rms": l1_to_rms_norm, + "rms_to_l1": rms_to_l1_norm, + "supremum": supremum_norm, + "condition_number": condition_number, + "frobenius_norm": frobenius_norm, + "average_entry_size": average_entry_size, + "stable_rank": stable_rank, + "effective_rank": effective_rank, +} + + +@torch.no_grad() +@torch.compile(fullgraph=True) +def fused_metrics(W, eps=1e-20): + if W.ndim < 2: + # Operator norms require a matrix. 
+ return {"supremum": W.abs().max()} + + Wf = W.float() + Wf_square = Wf * Wf + fan_out, fan_in = Wf.shape + + sup = Wf.abs().amax() + rowsqsum = Wf_square.sum(1) + colsqsum = Wf_square.sum(0) + + row_l2 = rowsqsum.sqrt() + col_l2 = colsqsum.sqrt() + + l1_to_rms = col_l2.max() / math.sqrt(fan_out) + rms_to_l1 = row_l2.max() * math.sqrt(fan_in) + + S = torch.linalg.svdvals(Wf, driver="gesvd") + + spec = S[0] * math.sqrt(fan_in / fan_out) + + cond = S[0] / (S[-1] + eps) + cond = cond.clamp_min(eps) + + frob_norm = row_l2.norm(p=2) + + spec_unscaled = S[0] + srank = (frob_norm**2) / (spec_unscaled**2 + eps) + srank = srank.clamp_min(eps) + + p = (S / (S.sum() + eps)).clamp_min(eps) + erank = torch.exp(-(p * p.log()).sum()) + + avg_entry = frob_norm / math.sqrt(fan_out * fan_in) + + return { + "rms_to_rms": spec, + "l1_to_rms": l1_to_rms, + "rms_to_l1": rms_to_l1, + "supremum": sup, + "condition_number": cond, + "frobenius_norm": frob_norm, + "average_entry_size": avg_entry, + "stable_rank": srank, + "effective_rank": erank, + } + + +def get_norms_to_log(norms_to_log: str | list[str]) -> list[str]: + """ + Return a list of norms to log. + The following contents in `norms_to_log` are special: + - "default": replaced by + ["rms_to_rms", "l1_to_rms", "rms_to_l1", "supremum", "condition_number"] + - "all" or "everything": log all norms + """ + if isinstance(norms_to_log, str): + norms_to_log = [norms_to_log] + # Remove duplicates while keeping order. + norms_to_log = list(dict.fromkeys(norms_to_log)) + + if "all" in norms_to_log or "everything" in norms_to_log: + return list(NORM_FUNCTIONS.keys()) + + if "default" in norms_to_log: + # Replace the "default" entry. + update_index = norms_to_log.index("default") + norms_to_log = ( + norms_to_log[:update_index] + + [ + "rms_to_rms", + "l1_to_rms", + "rms_to_l1", + "supremum", + "condition_number", + ] + + norms_to_log[update_index + 1 :] + ) + + return norms_to_log + + +def calculate_norm( + W: torch.Tensor, + norms_to_log: list[str] | None = None, + transpose: bool = False, + use_fused_metrics: bool = True, +) -> dict[str, torch.Tensor]: + """ + It is important to note that the order of the norms is the same + as the order of `norms_to_log`. 
+ """ + if norms_to_log is None: + norms_to_log = list(NORM_FUNCTIONS.keys()) + + W = W.to_local() if isinstance(W, DTensor) else W + if transpose: + W = W.transpose(0, 1) + if use_fused_metrics: + norms = fused_metrics(W) + norms = {norm_name: norms[norm_name] for norm_name in norms_to_log} + else: + norms = {norm_name: NORM_FUNCTIONS[W] for norm_name in norms_to_log} + + return norms diff --git a/torchtitan/experiments/distributed_scion/train_configs/debug_model.toml b/torchtitan/experiments/distributed_scion/train_configs/debug_model.toml new file mode 100644 index 000000000..4a4f334de --- /dev/null +++ b/torchtitan/experiments/distributed_scion/train_configs/debug_model.toml @@ -0,0 +1,109 @@ +[job] +dump_folder = "./outputs" +description = "DeepSeek-V3 debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel" +# test tokenizer, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "DistributedScion" +beta1 = 0.0 +beta2 = 0.0 +implementation = "fused" +early_step_in_backward = false +lr = 0.1 +eps = 1e-20 +weight_decay = 0.0 +is_light = false +norm_factor = "spectral" +zeropower_backend = "newtonschulz5" +backend_steps = 5 +momentum = 0.1 +nesterov = false + +[[optimizer.extra_splits_rules]] +lr = 0.1 +str_match = "tok_embeddings.weight" +norm_factor = "embed_sqrt" +backend = "identity" + +[[optimizer.extra_splits_rules]] +lr = 0.1 +str_match = "output.weight" +norm_factor = "unembed_sqrt" +backend = "identity" + +[[optimizer.extra_splits_rules]] +str_match = "\\.router.gate.weight" +norm_factor = "spectral" + +[[optimizer.extra_splits_rules]] +str_match = "ssnorm_scale" +norm_factor = "sign" +backend = "identity" +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 0.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/distributed_scion/utils.py b/torchtitan/experiments/distributed_scion/utils.py 
new file mode 100644 index 000000000..fdb802a3e --- /dev/null +++ b/torchtitan/experiments/distributed_scion/utils.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import re +from typing import Any + +import torch + +from torchtitan.tools.logging import logger + + +def remove_orig_mod_and_weight_for_p_name(name: str) -> str: + """ + Remove "._orig_mod", ".weight", and "._checkpoint_wrapped_module" to + get the clean layer name. + """ + name = re.sub(r"\._orig_mod", "", name) # comes from compiled model + name = re.sub(r"\.weight", "", name) # param.weight + name = re.sub( + r"\._checkpoint_wrapped_module", "", name + ) # comes from activation checkpointing + return name + + +def create_scion_optimizer_kwargs_from_optimizer_config( + optimizer_config, + parallel_dims, +) -> dict[str, Any]: + backend_steps = optimizer_config.backend_steps + zeropower_backend_algorithm = optimizer_config.zeropower_backend + momentum = optimizer_config.momentum + nesterov = optimizer_config.nesterov + is_light = optimizer_config.is_light + weight_decay = optimizer_config.weight_decay + lr = optimizer_config.lr + eps = optimizer_config.eps + if os.environ.get("SCION_DEBUG_GRAD") == "1": + norm_factor = "none" + zeropower_backend_algorithm = "identity" + logger.warning( + '`SCION_DEBUG_GRAD` is set to 1, we will not run SVD and use the "identity" backend' + ) + else: + norm_factor = "spectral" + + optimizer_kwargs = { + "parallel_dims": parallel_dims, + "is_light": is_light, + "weight_decay": weight_decay, + "lr": lr, + "momentum": momentum, + "nesterov": nesterov, + "eps": eps, + "norm_factor": norm_factor, + "backend": zeropower_backend_algorithm, + "backend_steps": backend_steps, + } + + # Add extra_splits_rules if present + if ( + hasattr(optimizer_config, "extra_splits_rules") + and optimizer_config.extra_splits_rules + ): + optimizer_kwargs["extra_splits_rules"] = optimizer_config.extra_splits_rules + + return optimizer_kwargs + + +def create_scion_param_groups( + model: torch.nn.Module, + optimizer_kwargs: dict[str, Any], +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + """ + Create and extract parameter groups for DistributedScion optimizer. + + This function combines parameter group configuration creation and parameter extraction: + 1. Creates parameter group configurations from optimizer_kwargs + 2. Extracts actual parameters from the model based on the configurations + 3. 
Returns clean kwargs for the optimizer (without extra_splits_rules) + + This function supports a new parameter group configuration system where: + - Default values are taken from the top-level optimizer_kwargs + - Special parameter groups are defined in the 'extra_splits_rules' list + - Each entry in extra_splits_rules can override any default value + + Example configuration: + ```python + optimizer_kwargs = { + "lr": 1e-3, + "weight_decay": 0.1, + "momentum": 0.9, + "norm_factor": "spectral", + "backend": "newtonschulz5", + "backend_steps": 5, + "extra_splits_rules": [ + { + "str_match": "embedding", + "lr": 1e-4, # Override default lr + "norm_factor": "embed_sqrt" # Override default norm_factor + }, + { + "str_match": "router", + "lr": 5e-4, + "backend": "identity" # Override default backend + } + ] + } + ``` + Args: + model: The model to extract parameters from + optimizer_kwargs: Dictionary containing optimizer configuration + + Returns: + Tuple of (parameter_groups, clean_kwargs) where: + - parameter_groups: List of parameter groups with actual parameters + - clean_kwargs: Clean kwargs for the optimizer (without extra_splits_rules) + """ + import functools + import re + from collections import OrderedDict + + # Step 1: Create parameter group configurations + param_groups_config = [] + + # Get default configuration + default_config = { + "lr": optimizer_kwargs.get("lr"), + "weight_decay": optimizer_kwargs.get("weight_decay"), + "momentum": optimizer_kwargs.get("momentum"), + "nesterov": optimizer_kwargs.get("nesterov"), + "eps": optimizer_kwargs.get("eps"), + "norm_factor": optimizer_kwargs.get("norm_factor"), + "backend": optimizer_kwargs.get("backend"), + "backend_steps": optimizer_kwargs.get("backend_steps"), + } + + # Process extra_splits_rules if provided + extra_splits_rules = optimizer_kwargs.get("extra_splits_rules", []) + for param_group in extra_splits_rules: + # Start with default config and override with param_group specific values + group_config = default_config.copy() + group_config.update(param_group) + + # Ensure str_match is present + if "str_match" not in group_config: + logger.warning("extra_splits_rules entry missing 'str_match', skipping") + continue + + # Rename str_match to param_str_match for compatibility + group_config["param_str_match"] = group_config.pop("str_match") + + param_groups_config.append(group_config) + + # Step 2: Extract actual parameters from the model + param_dict = OrderedDict( + (n, p) for n, p in model.named_parameters() if p.requires_grad + ) + params = [] + + for param_group_config in param_groups_config: + # Make a copy to avoid modifying the original + group_config = param_group_config.copy() + str_match = group_config.pop("param_str_match") + filter_fn = functools.partial(re.search, str_match) + param_names = [n for n in param_dict.keys() if filter_fn(n)] + + group_params = { + "params": [param_dict.pop(n) for n in param_names], + "param_names": param_names, + } + assert len(group_params["params"]) == len(group_params["param_names"]) + + if len(param_names) == 0: + logger.warning( + f'Notice: No parameters found for `str_match` "{str_match}" on ' + f"global rank {torch.distributed.get_rank()}" + ) + continue + group_params.update(group_config) + params.append(group_params) + + # Add remaining parameters as the default group + param_names = list(param_dict.keys()) + if param_names: + default_group = { + "params": [param_dict.pop(n) for n in param_names], + "param_names": param_names, + } + # Add default configuration to the default group 
+ default_group.update(default_config) + params.insert(0, default_group) + + # Create clean kwargs for the optimizer (remove extra_splits_rules) + clean_kwargs = { + k: v for k, v in optimizer_kwargs.items() if k != "extra_splits_rules" + } + + # for param in params: + # args = {k: v for k, v in param.items() if k != "params"} + # print(f"param: {args}") + + return params, clean_kwargs