Commit d655e16 (parent cf30b29)

Make token group alignment size configurable (#1503)

## Summary

- For mxfp8, token group sizes must be multiples of `block_size`. In the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor that is scaled separately, so each token group's contracting dimension must be divisible by the mxfp8 block_size (default 32). Diagram of the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment size configurable.

## Test plan

- Integration test with torchao passes: pytorch/ao#2642
- Manual test run with the llama4 debug model using bf16
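To make the constraint concrete, here is a minimal sketch of the padding rule the summary describes (a hypothetical helper, not code from this commit): each token group size is rounded up to the next multiple of the alignment size, so the contracting M dimension stays divisible by it.

```python
import math


def pad_group_sizes(group_sizes: list[int], align_size_m: int) -> list[int]:
    """Round each token group size up to the next multiple of align_size_m."""
    return [math.ceil(s / align_size_m) * align_size_m for s in group_sizes]


# With the mxfp8 block_size of 32, a 40-token group is padded to 64 and a 7-token
# group to 32, so grad_weight = grad_output_t @ input contracts over M in whole
# (1 x 32) scaling blocks for every group.
print(pad_group_sizes([40, 7, 96], 32))  # [64, 32, 96]
```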

File tree: 4 files changed, +56 −5 lines


torchtitan/components/quantization/float8.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -10,6 +10,9 @@
 
 from torchtitan.config.job_config import Float8, JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.llama4.infra.expert_parallel import (
+    set_token_group_alignment_size_m,
+)
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -66,6 +69,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 job_config.parallelism.context_parallel_degree == 1
             ), "Float8 MoE training prototype does not yet support context parallelism"
 
+            # For fp8 grouped GEMM, token group sizes must be multiples of 16
+            # (16 byte alignment / 1 byte per elem = 16 elements)
+            set_token_group_alignment_size_m(16)
+
         if float8_config.recipe_name is not None:
             assert not float8_config.enable_fsdp_float8_all_gather, (
                 "using `float8_config.enable_fsdp_float8_all_gather` together "
```

torchtitan/components/quantization/mx.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -59,6 +59,16 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             and job_config.parallelism.tensor_parallel_degree > 1
         ), "TP not yet supported with torch.compile for mxfp8"
 
+        # For MoE training with mxfp8, token group sizes must be multiples of 32
+        if job_config.mx.moe_fqns_prototype:
+            from torchtitan.experiments.llama4.infra.expert_parallel import (
+                set_token_group_alignment_size_m,
+            )
+
+            mxfp8_block_size = 32
+            set_token_group_alignment_size_m(mxfp8_block_size)
+            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
             MXFP8Dim1CastKernelChoice,
```
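Putting the three cases together, the following hypothetical helper (not part of the commit) summarizes how the converters above effectively pick the alignment size: byte alignment divided by element size for bf16/fp8, and the scaling block size for mxfp8.

```python
from typing import Literal

Recipe = Literal["bf16", "fp8", "mxfp8"]


def alignment_size_for(recipe: Recipe, mxfp8_block_size: int = 32) -> int:
    """Sketch of the token group alignment each quantization recipe needs for MoE."""
    if recipe == "mxfp8":
        # Scaling blocks are (1 x 32), so the contracting M dim must be a multiple of block_size.
        return mxfp8_block_size
    bytes_per_elem = {"bf16": 2, "fp8": 1}[recipe]
    return 16 // bytes_per_elem  # 16-byte alignment / bytes per element


assert alignment_size_for("bf16") == 8
assert alignment_size_for("fp8") == 16
assert alignment_size_for("mxfp8") == 32
```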

torchtitan/config/job_config.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -567,12 +567,19 @@ class MX:
 
     filter_fqns: list[str] = field(default_factory=lambda: ["output"])
     """
-    Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to.
+    Comma-separated list of fully qualified names of modules to skip applying mxfp8 training to.
     nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements.
     By default we always skip the output layer.
     Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output"
     """
 
+    moe_fqns_prototype: list[str] | str = field(default_factory=list)
+    """
+    Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to.
+    This is a prototype feature that requires the torchao nightly build.
+    Example: --mx.moe_fqns_prototype="experts"
+    """
+
 
 @dataclass
 class Comm:
```
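For illustration only, a sketch of how a comma-separated `moe_fqns_prototype` value could be parsed and matched against module names. The helper names here are hypothetical; the commit's actual filtering happens inside the model converter, which is not shown in this diff.

```python
def parse_fqns(value: list[str] | str) -> list[str]:
    """Accept either an already-split list or a comma-separated string like "experts"."""
    if isinstance(value, str):
        return [fqn.strip() for fqn in value.split(",") if fqn.strip()]
    return list(value)


def matches_any_fqn(module_fqn: str, target_fqns: list[str]) -> bool:
    """True if any configured FQN fragment appears in the module's fully qualified name."""
    return any(target in module_fqn for target in target_fqns)


fqns = parse_fqns("experts")
assert matches_any_fqn("layers.3.moe.experts", fqns)
assert not matches_any_fqn("layers.3.attention.wq", fqns)
```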

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 31 additions & 4 deletions
```diff
@@ -6,7 +6,7 @@
 
 
 from functools import partial
-from typing import Callable
+from typing import Callable, Literal
 
 import torch
 import torch.distributed as dist
@@ -24,6 +24,33 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
+TOKEN_GROUP_ALIGN_SIZE_M = 8
+ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
+
+
+def set_token_group_alignment_size_m(
+    alignment_size: ValidTokenGroupAlignmentSize,
+) -> None:
+    """
+    Set the token group alignment size for token groups in MoE. This is implemented by
+    padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+    Valid values are: 8, 16, or 32.
+    Different values are needed for different cases:
+
+    * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
+    * For fp8, 16 byte alignment / 1 byte per elem = 16 elements.
+    * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32),
+      so when doing per-token-group quantization on each logically distinct subtensor,
+      we need to ensure the contracting dim is divisible by block_size.
+      In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims
+      of (N, M) @ (M, K), so M is the contracting dim, and group offsets are along M,
+      so we need 32 element alignment.
+    """
+    global TOKEN_GROUP_ALIGN_SIZE_M
+    TOKEN_GROUP_ALIGN_SIZE_M = alignment_size
+
+
 # implementation of Tensor Parallel for the GroupedExperts in MoE
 class TensorParallel(ParallelStyle):
     def _partition_fn(self, name, module, device_mesh):
@@ -251,6 +278,7 @@ def wrapper(
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
             w1 = w1.to_local()
             w2 = w2.to_local()
@@ -264,7 +292,6 @@ def wrapper(
             experts_per_ep_rank = w1.shape[0]
             num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
 
-            ALIGN_SIZE_M = 16
             with torch.no_grad():
                 (
                     permuted_indices,
@@ -274,8 +301,8 @@ def wrapper(
                     num_tokens_per_expert,
                     experts_per_ep_rank,
                     num_ep_ranks,
-                    x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
-                    ALIGN_SIZE_M,
+                    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                    TOKEN_GROUP_ALIGN_SIZE_M,
                 )
 
             x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
```
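As a rough illustration of what the configurable `TOKEN_GROUP_ALIGN_SIZE_M` buys in the wrapper above (this is a sketch of the implied padding, not the real permute-indices kernel the wrapper calls), each group size is rounded up to a multiple of the alignment:

```python
import torch


def padded_group_offsets(num_tokens_per_group: torch.Tensor, align_size_m: int) -> torch.Tensor:
    """Round each group size up to a multiple of align_size_m and return cumulative end offsets."""
    padded = (num_tokens_per_group + align_size_m - 1) // align_size_m * align_size_m
    return torch.cumsum(padded, dim=0)


sizes = torch.tensor([40, 7, 96])
print(padded_group_offsets(sizes, 8))   # tensor([ 40,  48, 144])  - bf16 default alignment
print(padded_group_offsets(sizes, 32))  # tensor([ 64,  96, 192])  - mxfp8 block size
```

Since rounding up adds at most `TOKEN_GROUP_ALIGN_SIZE_M - 1` tokens per group, this also explains the `x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M` upper bound passed to the kernel above.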
