
Only apply grouped GEMM padding for MXFP8 and FP8 non-HybridEP cases#2620

Open
danielvegamyhre wants to merge 4 commits intomainfrom
paddingupdate

Conversation

@danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Mar 18, 2026

Context

  • BF16 grouped GEMM no longer requires padding, so we can remove it from the BF16 path and apply it only for FP8 and MXFP8 grouped GEMMs.
  • TorchTitan will now contain only a torch-native "rank major to expert major" permutation impl for BF16 grouped GEMM, and no extra per-group padding kernels/logic for FP8/MXFP8 (these will live in torchao, which, as the quantization library, is a better home for them).

Summary

There are 7 cases to handle:

  • Case 1: BF16 + No EP
    • (do nothing)
  • Case 2: BF16 + EP
    • Torch native impl handles permute from rank major to expert major (no padding)
  • Case 3: MXFP8 + No EP
    • Handled with pad/unpad kernels in torchao
  • Case 4: MXFP8 + Standard EP
    • torchao permute_and_pad() is called in the ExpertParallel implementation when token_group_alignment_size > 0
  • Case 5: MXFP8 + HybridEP
    • HybridEP handles token group padding for MXFP8 grouped GEMM as part of the all2all dispatch
  • Case 6: FP8 + No EP
    • Same as case 3
  • Case 7: FP8 + EP
    • Same as case 4
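The seven cases above reduce to one question: does this path need per-group padding at all? A minimal sketch of that decision, assuming hypothetical names (`Quant`, `needs_group_padding`) that do not appear in torchtitan itself:

```python
from enum import Enum


class Quant(Enum):
    BF16 = "bf16"
    FP8 = "fp8"
    MXFP8 = "mxfp8"


def needs_group_padding(quant: Quant, hybrid_ep: bool) -> bool:
    # BF16 grouped GEMM no longer needs padding at all (cases 1-2).
    # For FP8/MXFP8, HybridEP pads inside its all2all dispatch (case 5),
    # so only the non-HybridEP paths pad explicitly, via torchao kernels
    # (cases 3, 6) or the ExpertParallel implementation (cases 4, 7).
    if quant is Quant.BF16:
        return False
    return not hybrid_ep
```

This is only a summary of the case table, not the actual dispatch code in the PR.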

Misc changes

  • Delete kernels.py
  • Delete tests for those kernels
  • Remove pad_token_groups_for_grouped_mm option from MXFP8ConverterConfig, since we can set it correctly automatically
  • Added debug models for float8 and mxfp8 to config registry to speed up future development

Tests

FP8 tests were done with fp8 grouped mm only, not fp8 linear. Using both, I get this weird tyro error:

[rank0]:│ model-converters.converters.0:config was not a match because:                │
[rank0]:│ • Default value Config(enable_fsdp_float8_all_gather=False,                  │
[rank0]:│   precompute_float8_dynamic_scale_for_fsdp=False, recipe_name=None,          │
[rank0]:│   filter_fqns=['output', 'router.gate'], emulate=False) with type Config     │
[rank0]:│   does not match type <class 'torchtitan.components.quantization.float8.Floa │
[rank0]:│   t8GroupedMMConverter.Config'>         

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 18, 2026
@danielvegamyhre
Contributor Author

Have not tested yet because of devgpu issues, but if you want to take a look, feel free @tianyu-l

@danielvegamyhre danielvegamyhre force-pushed the paddingupdate branch 3 times, most recently from f906b3d to f79fff4 Compare March 20, 2026 04:14
@danielvegamyhre
Contributor Author

I finished testing, @tianyu-l. This is ready for review.

from torchtitan.tools.utils import _round_up

from .kernels import generate_permute_indices
TOKEN_GROUP_ALIGN_SIZE_M = 0
Contributor
we should remove this -- setting global variables is error-prone

we should move logic to parallelize functions for various combinations

Contributor Author

Makes sense, I am working on this. The changes are straightforward for standard EP, but for Hybrid EP it seems it will require either (1) refactoring the custom ops, DispatchState, etc. to pass around the quantization type used, or (2) a module-level variable storing the quantization type, similar to _buffer. I think (2) is the less invasive change, wdyt?

Comment on lines +61 to +65
def maybe_align_num_tokens_for_mxfp8(num_tokens: int) -> int:
    """Round up token count only when MXFP8 group alignment is active."""
    if TOKEN_GROUP_ALIGN_SIZE_M != MXFP8_GROUP_ALIGNMENT_SIZE:
        return num_tokens
    return _round_up(num_tokens, MXFP8_GROUP_ALIGNMENT_SIZE)
Contributor

move this logic to hybridep.py, including _round_up (as an inline function) which is currently only used once in this repo

# FP8/MXFP8 require groups to be permuted to expert major order AND padded to
# `alignment_size`.
# Otherwise, we only need to permute to expert major order.
if self.token_group_alignment > 0:
Contributor

IMO the proper way is to create e.g. FP8ExpertParallel and dispatch to it in parallelize function, instead of making if-else in existing ExpertParallel.

Also the condition should be whether quantization is used, not the token_group_alignment size set from somewhere.

Contributor Author

That works, did a refactor



# Source: https://github.com/pytorch/torchtitan/pull/2255
def _generate_permute_indices(
Contributor

could you verify that before vs. after, we get bitwise identical results under same seed and determinism?

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]

def indices_padding_wrapper(func: Callable) -> Callable:
Contributor

I don't think we need this function any more. Please remove this and simplify

# NOTE: If EP is not used, we need to pad the indices
# to prepare for grouped_mm;
# otherwise, EP will handle the padding.
if (
    not isinstance(self.w1, DTensor)
    # pyrefly: ignore[not-iterable]
    or "ep" not in self.w1.device_mesh.mesh_dim_names
):
    run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
else:
    run_experts_fn = _run_experts_grouped_mm
return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)

@@ -45,10 +45,9 @@ def backward(ctx, grad_output):

def indices_padding_wrapper(func: Callable) -> Callable:
Contributor

same

num_tokens_per_expert_group,
ep_degree,
num_local_experts,
FLOAT8_GROUP_ALIGNMENT_SIZE,
Contributor

why do you need two different classes? You could just init with different quantization type, which can be used to determine the alignment size, e.g. based on a static dict.

Comment on lines +35 to 36
FLOAT8_GROUP_ALIGNMENT_SIZE = 16
MXFP8_GROUP_ALIGNMENT_SIZE = 32
Contributor

make this a dict from quantization type to alignment size
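The suggested change would collapse the two constants into one mapping, e.g. (a sketch; `GROUP_ALIGNMENT_SIZE` is a hypothetical name, and `QuantizationType` is assumed to be the enum introduced in this PR):

```python
from enum import Enum, auto


class QuantizationType(Enum):
    FLOAT8 = auto()
    MXFP8 = auto()


# Single source of truth replacing the separate
# FLOAT8_GROUP_ALIGNMENT_SIZE / MXFP8_GROUP_ALIGNMENT_SIZE constants.
GROUP_ALIGNMENT_SIZE = {
    QuantizationType.FLOAT8: 16,
    QuantizationType.MXFP8: 32,
}
```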

Comment on lines +31 to +34
if find_float8_grouped_mm_config(model_converters):
    return QuantizationType.FLOAT8
elif config := find_mxfp8_config(model_converters):
    if routed_experts_in_fqns(config.fqns):
Contributor

no need to modularize into multiple small functions which are not used elsewhere -- we can make everything in a single util function for now

from torchtitan.protocols import ModelConverter


class QuantizationType(Enum):
Contributor

# as part of the EP implementation.
# Otherwise, if EP is not enabled, we need TorchAO to pad the token groups.
self.pad_token_groups_for_grouped_mm = not parallel_dims.ep_enabled
logger.warning(
Contributor

why it's a warning? sounds like a comment to me, especially when both hybridEP is used this warning would still be there

group: ProcessGroup,
score_before_experts: bool = True,
non_blocking_expert_capacity_factor: float | None = None,
quantization_type: QuantizationType | None = None,
Contributor

hybridep module doesn't need to know the quantization_type. All it needs to know is pad multiple size.
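In other words, the hybridep interface can take just an integer pad multiple. A minimal sketch of what that padding amounts to (`pad_group_sizes` is a hypothetical helper, not a torchtitan function):

```python
def pad_group_sizes(num_tokens_per_group: list[int], pad_multiple: int) -> list[int]:
    # HybridEP only needs the pad multiple, not the quantization type:
    # round each group's token count up so every group is divisible by it.
    if pad_multiple <= 1:
        return list(num_tokens_per_group)
    return [
        ((n + pad_multiple - 1) // pad_multiple) * pad_multiple
        for n in num_tokens_per_group
    ]
```

The caller (which knows the quantization type) maps it to a multiple, e.g. 32 for MXFP8, and passes only that number down.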

)


class Float8ExpertParallel(BaseExpertParallel):
Contributor

can you inherit ExpertParallel instead of BaseExpertParallel, which can save a lot of code?
