From 2bc8578249c3e7a4066c3910aa6a19fb4e814a4a Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Mon, 1 Sep 2025 12:34:19 +0200
Subject: [PATCH] we can set DEBUG_FORCE_LOAD_BALANCED=1 to force each expert
 to receive the same number of tokens

---
 torchtitan/models/moe.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 8be14ecbf..7d0025c2b 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 from dataclasses import dataclass
 from typing import Literal
 
@@ -188,6 +189,19 @@ def __init__(
         self.score_func = score_func
         self.route_norm = route_norm
         self.route_scale = route_scale
+        self.debug_force_load_balanced = bool(
+            int(os.getenv("DEBUG_FORCE_LOAD_BALANCED", "0"))
+        )
+
+    @staticmethod
+    def uniform_indices(
+        n_tokens: int, top_k: int, num_experts: int, device
+    ) -> torch.Tensor:
+        """Round-robin expert assignment with exact balance at each step.
+        Returns a LongTensor of shape (n_tokens, top_k)."""
+        i = torch.arange(n_tokens, device=device)[:, None]  # [N, 1]
+        k = torch.arange(top_k, device=device)[None, :]  # [1, K]
+        return ((i * top_k + k) % num_experts).long()  # [N, K]
 
     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
@@ -231,6 +245,13 @@ def forward(
             scores, k=self.top_k, dim=1
         )
 
+        # debug override: balanced round-robin routing
+        if self.debug_force_load_balanced:
+            selected_experts_indices = self.uniform_indices(
+                x.size(0), self.top_k, self.num_experts, x.device
+            )
+            top_scores = scores.gather(dim=1, index=selected_experts_indices)
+
         if self.score_func == "sigmoid" and self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator
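
Verification sketch (not part of the patch): the commit relies on the fact
that dealing the N*K (token, slot) pairs round-robin across E experts gives
an exactly uniform assignment whenever N*K is a multiple of E. The snippet
below copies uniform_indices out of the router class so it is self-contained;
the numbers chosen are illustrative only. To enable the override during
training, export DEBUG_FORCE_LOAD_BALANCED=1 before launching (the exact
entry point depends on the setup).

    import torch

    def uniform_indices(
        n_tokens: int, top_k: int, num_experts: int, device: str = "cpu"
    ) -> torch.Tensor:
        # Assign token i's k-th slot to expert (i * top_k + k) mod E,
        # i.e. deal all N*K slots round-robin across the E experts.
        i = torch.arange(n_tokens, device=device)[:, None]  # [N, 1]
        k = torch.arange(top_k, device=device)[None, :]     # [1, K]
        return ((i * top_k + k) % num_experts).long()       # [N, K]

    idx = uniform_indices(n_tokens=64, top_k=2, num_experts=8)
    counts = torch.bincount(idx.flatten(), minlength=8)
    assert (counts == 64 * 2 // 8).all()  # each expert gets exactly 16 slots
    print(counts.tolist())                # [16, 16, 16, 16, 16, 16, 16, 16]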