
Commit 2bfcdd8

rakkit and wang55 authored

improve MoE bias update logic in optimizer (#1593)

We put all experts' usage into a single buffer so that only one all-reduce is needed, rather than one reduce per MoE layer. Additionally, handle the case where tokens per expert are counted twice during full recompute.

Co-authored-by: wang55 <[email protected]>

1 parent fd23080 · commit 2bfcdd8
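The batching idea from the commit message, as a minimal sketch (names such as `reduce_expert_usage` and `per_layer_counts` are illustrative, not from the repo): stack each MoE layer's per-expert token counts into one (num_layers, num_experts) tensor and issue a single all-reduce, instead of one collective per layer.

import torch
import torch.distributed as dist

def reduce_expert_usage(per_layer_counts: list[torch.Tensor], group=None) -> torch.Tensor:
    # per_layer_counts: one 1-D tensor of shape (num_experts,) per MoE layer
    stacked = torch.vstack(per_layer_counts)  # (num_layers, num_experts)
    if dist.is_initialized():
        # one collective for every layer at once, instead of num_layers separate reduces
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM, group=group)
    return stacked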

File tree

2 files changed: +61 −29 lines

torchtitan/components/optimizer.py
torchtitan/models/moe.py


torchtitan/components/optimizer.py

Lines changed: 47 additions & 16 deletions
@@ -9,6 +9,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl
 from torch.distributed.checkpoint.state_dict import (
     get_optimizer_state_dict,
     set_optimizer_state_dict,
@@ -340,6 +341,9 @@ def build_optimizers_with_moe_load_balancing(
     )
 
     # for MoE auxiliary-loss-free load balancing
+    def _is_recomputation_enabled(module):
+        return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
+
     def _update_expert_bias(
         model_parts: list[nn.Module],
         parallel_dims: ParallelDims,
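For context on the `_is_recomputation_enabled` helper added above: wrapping a block with non-reentrant activation checkpointing leaves a `checkpoint_impl` attribute on the wrapper, which is what the helper inspects. A hedged illustration on recent PyTorch (the `nn.Linear` stand-in is arbitrary):

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

block = nn.Linear(8, 8)  # stand-in for a transformer block
wrapped = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
# The helper's check should hold for the wrapped module but not for the plain one.
print(getattr(wrapped, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT)  # True
print(getattr(block, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT)    # False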
@@ -349,25 +353,52 @@ def _update_expert_bias(
         )
         # TODO: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
+        tokens_per_expert_list = []
         for model_part in model_parts:
             for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
+                if not transformer_block.moe_enabled:
+                    continue
+                if transformer_block.moe.load_balance_coeff is None:
+                    return
+                tokens_per_expert = transformer_block.moe.tokens_per_expert
+                if _is_recomputation_enabled(transformer_block):
+                    # TODO: This is a hack; we assume that with full AC, tokens_per_expert is counted twice.
+                    # This does not affect expert choice, but it does affect the expert usage metrics.
+                    # We divide by 2 to correct for the double-counting due to recomputation.
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
+                    tokens_per_expert = tokens_per_expert // 2
+                tokens_per_expert_list.append(tokens_per_expert)
+
+        tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
+
+        if dp_cp_mesh is not None:
+            # Perform a single all-reduce to get global statistics across all processes
+            pg = dp_cp_mesh.get_group()
+            torch.distributed.all_reduce(
+                tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
+            )
+
+        moe_layer_idx = 0
+        with torch.no_grad():
+            for model_part in model_parts:
+                for transformer_block in model_part.layers.values():
+                    if not transformer_block.moe_enabled:
+                        continue
                     moe = transformer_block.moe
-                    if moe.load_balance_coeff is None:
-                        return
-
-                    if dp_cp_mesh is not None:
-                        torch.distributed.all_reduce(
-                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
-                        )
-
-                    with torch.no_grad():
-                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
-                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
-                        )
-                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-                        moe.expert_bias.add_(expert_bias_delta)
-                        moe.tokens_per_expert.zero_()
+
+                    tokens_per_expert = tokens_per_expert_by_layer[moe_layer_idx].float()
+                    moe_layer_idx += 1
+
+                    # Update the expert bias; this is not exactly the update proposed in
+                    # https://arxiv.org/pdf/2408.15664
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                        tokens_per_expert.mean() - tokens_per_expert
+                    )
+                    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                    moe.expert_bias.add_(expert_bias_delta)
+                    moe.tokens_per_expert.zero_()
 
     optimizers.register_step_pre_hook(
         lambda *args, **kwargs: _update_expert_bias(
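For intuition about the sign-based update applied per layer above, here is a self-contained sketch with made-up numbers; the variable names mirror the diff, but the values are purely illustrative:

import torch

# Four experts; the first receives far more tokens than the rest.
tokens_per_expert = torch.tensor([900.0, 300.0, 400.0, 400.0])
load_balance_coeff = 1e-3
expert_bias = torch.zeros(4)

# Over-used experts (count above the mean) get nudged down, under-used ones up.
expert_bias_delta = load_balance_coeff * torch.sign(
    tokens_per_expert.mean() - tokens_per_expert
)
# Re-center so the bias stays zero-mean and only reshuffles routing preference.
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
expert_bias += expert_bias_delta
print(expert_bias)  # tensor([-0.0015, 0.0005, 0.0005, 0.0005])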

torchtitan/models/moe.py

Lines changed: 14 additions & 13 deletions
@@ -350,13 +350,14 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
                 torch.zeros(num_experts, dtype=torch.float32),
                 persistent=True,
             )
-            self.register_buffer(
-                "tokens_per_expert",
-                torch.zeros(num_experts, dtype=torch.float32),
-                persistent=False,
-            )
         else:
             self.expert_bias = None
+        # tokens_per_expert will be used to track expert usage and to update the expert bias for load balancing
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=False,
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -378,12 +379,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.router(x, self.expert_bias)
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
+        # and also to count the expert usage
         # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
         # first in the forward pass, and then in the backward pass. However, this has no
         # effect on the expert bias update thanks to the torch.sign() operator.
-        if self.load_balance_coeff is not None:
-            with torch.no_grad():
-                self.tokens_per_expert.add_(num_tokens_per_expert)
+        with torch.no_grad():
+            self.tokens_per_expert.add_(num_tokens_per_expert)
 
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
@@ -444,11 +445,11 @@ def init_weights(
         if self.shared_experts is not None:
             self.shared_experts.init_weights(init_std)
 
-        if self.load_balance_coeff is not None:
-            with torch.device(buffer_device):
+        with torch.device(buffer_device):
+            self.tokens_per_expert = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )
+            if self.load_balance_coeff is not None:
                 self.expert_bias = torch.zeros(
                     self.experts.num_experts, dtype=torch.float32
                 )
-                self.tokens_per_expert = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
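Because tokens_per_expert is now registered unconditionally (and non-persistently, so it never enters checkpoints), it can serve as an expert-usage counter even when load balancing is disabled. A hedged sketch of reading it for logging; the `log_expert_usage` helper, its `moe_layers` argument, and the print format are illustrative, not part of the repo:

import torch

def log_expert_usage(moe_layers, step: int) -> None:
    # moe_layers: iterable of MoE modules exposing the tokens_per_expert buffer
    with torch.no_grad():
        for idx, moe in enumerate(moe_layers):
            usage = moe.tokens_per_expert
            frac = usage / usage.sum().clamp(min=1)  # fraction of routed tokens per expert
            print(f"step={step} layer={idx} usage={[round(x, 3) for x in frac.tolist()]}")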
