Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
530 changes: 0 additions & 530 deletions tests/unit_tests/test_permute_indices_kernel.py

This file was deleted.

2 changes: 1 addition & 1 deletion torchtitan/components/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ class Config(Configurable.Config):


# Module level global constants
FP8_GROUP_ALIGNMENT_SIZE = 16
FLOAT8_GROUP_ALIGNMENT_SIZE = 16
MXFP8_GROUP_ALIGNMENT_SIZE = 32
Comment on lines +35 to 36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a dict from quantization type to alignment size

19 changes: 11 additions & 8 deletions torchtitan/components/quantization/float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@
import torch
import torch._inductor.config
import torch.nn as nn
from torchtitan.components.quantization import (
FP8_GROUP_ALIGNMENT_SIZE,
QuantizationConverter,
)
from torchtitan.components.quantization import QuantizationConverter
from torchtitan.distributed import ParallelDims

from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe.utils import set_token_group_alignment_size_m
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability

Expand Down Expand Up @@ -261,9 +257,6 @@ def __init__(
not parallel_dims.cp_enabled
), "Float8 MoE training prototype does not yet support context parallelism"

# For fp8 grouped GEMM, token group sizes must be multiples of 16
# (16 byte alignment / 1 byte per elem = 16 elements)
set_token_group_alignment_size_m(FP8_GROUP_ALIGNMENT_SIZE)
self.enabled = True

def convert(self, model: nn.Module):
Expand Down Expand Up @@ -317,3 +310,13 @@ def find_float8_linear_config(
(c for c in converters if isinstance(c, Float8LinearConverter.Config)),
None,
)


def find_float8_grouped_mm_config(
    converters: list,
) -> Float8GroupedMMConverter.Config | None:
    """Return the first Float8GroupedMMConverter.Config in *converters*, if any.

    Args:
        converters: List of converter config objects to search.

    Returns:
        The first matching Float8GroupedMMConverter.Config, or None when absent.
    """
    for candidate in converters:
        if isinstance(candidate, Float8GroupedMMConverter.Config):
            return candidate
    return None
47 changes: 33 additions & 14 deletions torchtitan/components/quantization/mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ class Config(QuantizationConverter.Config):
This is a prototype feature that requires the torchao nightly build.
"""

pad_token_groups_for_grouped_mm: bool = True
"""
Boolean indicating if token group sizes should be padded to multiple of 32 (MXFP8 scaling block size)
for compatibility with quantization kernels. Default is true.

If using HybridEP, set to false. HybridEP automatically performs this padding as part of the
all-to-all dispatch step, so running the padding/unpadding kernels would incur unnecessary extra overhead.
"""

def __init__(
self,
config: Config,
Expand All @@ -66,6 +57,7 @@ def __init__(
model_compile_enabled: bool,
):
self.enabled = False
self.convert_grouped_gemms = False

# Ensure minimum torchao versions
if find_spec("torchao") is None:
Expand All @@ -78,21 +70,38 @@ def __init__(
10, 0
), "MXFP8 is only supported on SM100 or architectures"

# Avoids confusing bugs where fqns is a string we iterate
# over each character instead of iterating each full string in a list
assert isinstance(
config.fqns, list
), "MXFP8Converter.Config.fqns must be a Python list of strings"

# Warn user if torch.compile is not enabled
if not model_compile_enabled:
logger.warning(
"torch.compile enablement is required for highest performance of MXFP8 dynamic quantization."
)

logger.info("MXFP8 MoE training enabled")

# If EP is enabled, TorchTitan handles the token group padding for MXFP8 grouped GEMM
# as part of the EP implementation.
# Otherwise, if EP is not enabled, we need TorchAO to pad the token groups.
self.pad_token_groups_for_grouped_mm = not parallel_dims.ep_enabled
logger.warning(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this a warning? It reads like an informational comment to me — especially since the warning would still be emitted even when HybridEP is already in use.

"For applying MXFP8 to MoE grouped GEMMs, use HybridEP with MXFP8 for best performance. "
"This fuses token group padding for MXFP8 grouped GEMM into the all-to-all dispatch, "
"substantially improving performance."
)

self.config = config
self.enabled = True
logger.info("MXFP8 MoE training enabled")

def convert(self, model: nn.Module):
"""
Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor.
This will use low precision grouped GEMMs with dynamic quantization using the specified MX dtype,
rather than the default high precision grouped GEMMs, for the target MoE FQNs.
Mutates the model inplace replacing instances of nn.Parameter with MXFP8TrainingWeightWrapperTensor
for the target FQNs. This will dispatch linear and grouped_mm ops to the appropriate autograd
functions that implement dynamic MXFP8 quantization + MXFP8 linear/grouped_mm ops.
"""
if not self.enabled:
return
Expand All @@ -119,7 +128,7 @@ def module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
recipe = MXFP8TrainingRecipe(self.config.recipe_name)
mxfp8_op_config = MXFP8TrainingOpConfig.from_recipe(recipe)
mxfp8_op_config.pad_token_groups_for_grouped_mm = (
self.config.pad_token_groups_for_grouped_mm
self.pad_token_groups_for_grouped_mm
)

quantize_(model, config=mxfp8_op_config, filter_fn=module_filter_fn)
Expand All @@ -138,3 +147,13 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
MXFP8 training doesn't require any post-optimizer hooks at the moment
"""
return


def find_mxfp8_config(
    converters: list,
) -> MXFP8Converter.Config | None:
    """Return the first MXFP8Converter.Config in *converters*, if any.

    Args:
        converters: List of converter config objects to search.

    Returns:
        The first matching MXFP8Converter.Config, or None when absent.
    """
    for candidate in converters:
        if isinstance(candidate, MXFP8Converter.Config):
            return candidate
    return None
28 changes: 28 additions & 0 deletions torchtitan/components/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,35 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import auto, Enum

import torch.nn as nn
from torchtitan.components.quantization.float8 import find_float8_grouped_mm_config
from torchtitan.components.quantization.mx import find_mxfp8_config
from torchtitan.protocols import ModelConverter


class QuantizationType(Enum):
    """Low-precision quantization families usable for MoE grouped GEMMs."""

    # Float8 rowwise quantization (16-element group alignment).
    FLOAT8 = auto()
    # Microscaling FP8 (32-element scaling-block group alignment).
    MXFP8 = auto()


def routed_experts_in_fqns(fqns: list[str]) -> bool:
    """Helper used to determine if quantization is targeting routed experts (grouped GEMMs).

    Args:
        fqns: Fully qualified module names targeted for quantization.

    Returns:
        True when any FQN references routed experts ("experts" but not
        "shared_experts"), False otherwise.
    """
    return any(
        "experts" in fqn and "shared_experts" not in fqn for fqn in fqns
    )


def get_grouped_mm_quantization_type(
    model_converters: list[ModelConverter],
) -> QuantizationType | None:
    """Determine which quantization type (if any) applies to MoE grouped GEMMs.

    Args:
        model_converters: Converter configs attached to the model.

    Returns:
        QuantizationType.FLOAT8 when a Float8 grouped-MM converter config is
        present; QuantizationType.MXFP8 when an MXFP8 converter config targets
        routed experts; None otherwise (explicit, rather than the previous
        implicit fall-through).
    """
    if find_float8_grouped_mm_config(model_converters):
        return QuantizationType.FLOAT8
    if (config := find_mxfp8_config(model_converters)) and routed_experts_in_fqns(
        config.fqns
    ):
        return QuantizationType.MXFP8
    return None


def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
Expand Down
52 changes: 45 additions & 7 deletions torchtitan/distributed/deepep/hybridep.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@
)
from torch.distributed import ProcessGroup

from torchtitan.models.common.moe.utils import (
get_mxfp8_pad_multiple,
maybe_align_num_tokens_for_mxfp8,
)

from torchtitan.components.quantization import MXFP8_GROUP_ALIGNMENT_SIZE
from torchtitan.components.quantization.utils import QuantizationType

_buffer: Any = None # Global buffer instance
_quantization_type: QuantizationType | None = None # Current quantization type


class DispatchHandle(OpaqueBase):
Expand Down Expand Up @@ -308,8 +306,6 @@ def _combine_backward(ctx, grad_combined):

# Must pass pad_multiple so backward gradients entering ScaledGroupedMM
# (torchao MXFP8) also have rows aligned to 32.
from torchtitan.models.common.moe.utils import get_mxfp8_pad_multiple

pad_multiple = get_mxfp8_pad_multiple()

grad_x, _, _, _, _ = _buffer.dispatch_with_permute(
Expand Down Expand Up @@ -404,6 +400,7 @@ def dispatch_tokens(
group: ProcessGroup,
score_before_experts: bool = True,
non_blocking_expert_capacity_factor: float | None = None,
quantization_type: QuantizationType | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hybridep module doesn't need to know the quantization_type. All it needs to know is pad multiple size.

) -> tuple[torch.Tensor, torch.Tensor, DispatchState]:
"""Dispatch tokens to experts via HybridEP all-to-all.

Expand All @@ -419,10 +416,15 @@ def dispatch_tokens(
float in (0, 1] = non-blocking mode; pre-sizes the permute output
tensor as num_tokens × ep_size × min(num_local_experts, top_k) × cf,
aligned for MXFP8.
quantization_type: Quantization type (FLOAT8, MXFP8, or None) for
determining padding requirements.

Returns:
(permuted_hidden, tokens_per_expert, state)
"""
global _quantization_type
_quantization_type = quantization_type

non_blocking = non_blocking_expert_capacity_factor is not None

selected_experts_indices = selected_experts_indices.contiguous()
Expand Down Expand Up @@ -474,6 +476,42 @@ def combine_tokens(hidden_states: torch.Tensor, state: DispatchState) -> torch.T
return torch.ops.hybridep.combine(hidden_states, state.handle, state.num_tokens)


def get_mxfp8_pad_multiple() -> int | None:
    """Return the pad multiple needed for MXFP8 grouped GEMMs, or None if inactive.

    When the module-level _quantization_type is MXFP8, dispatch kernels must pad
    per-expert token groups to MXFP8_GROUP_ALIGNMENT_SIZE (32) so the
    quantization kernel's row-count requirement is satisfied.

    Returns:
        MXFP8_GROUP_ALIGNMENT_SIZE (32) when MXFP8 is active, None otherwise.
    """
    if _quantization_type == QuantizationType.MXFP8:
        return MXFP8_GROUP_ALIGNMENT_SIZE
    return None


def maybe_align_num_tokens_for_mxfp8(num_tokens: int) -> int:
    """Round the token count up only when MXFP8 group alignment is active.

    Args:
        num_tokens: The number of tokens to potentially align.

    Returns:
        num_tokens rounded up to the next multiple of MXFP8_GROUP_ALIGNMENT_SIZE
        when MXFP8 quantization is active; num_tokens unchanged otherwise.
    """
    if _quantization_type != QuantizationType.MXFP8:
        return num_tokens
    align = MXFP8_GROUP_ALIGNMENT_SIZE
    # Classic ceiling-division round-up to the alignment boundary.
    return (num_tokens + align - 1) // align * align


__all__ = [
"dispatch_tokens",
"combine_tokens",
Expand Down
Loading
Loading