-
Notifications
You must be signed in to change notification settings - Fork 755
Only apply grouped GEMM padding for MXFP8 and FP8 non-HybridEP cases #2620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,15 +49,6 @@ class Config(QuantizationConverter.Config): | |
| This is a prototype feature that requires the torchao nightly build. | ||
| """ | ||
|
|
||
| pad_token_groups_for_grouped_mm: bool = True | ||
| """ | ||
| Boolean indicating if token group sizes should be padded to multiple of 32 (MXFP8 scaling block size) | ||
| for compatibility with quantization kernels. Default is true. | ||
|
|
||
| If using HybridEP, set to false. HybridEP automatically performs this padding as part of the | ||
| all-to-all dispatch step, so running the padding/unpadding kernels would incur unnecessary extra overhead. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: Config, | ||
|
|
@@ -66,6 +57,7 @@ def __init__( | |
| model_compile_enabled: bool, | ||
| ): | ||
| self.enabled = False | ||
| self.convert_grouped_gemms = False | ||
|
|
||
| # Ensure minimum torchao versions | ||
| if find_spec("torchao") is None: | ||
|
|
@@ -78,21 +70,38 @@ def __init__( | |
| 10, 0 | ||
| ), "MXFP8 is only supported on SM100 or architectures" | ||
|
|
||
| # Avoids confusing bugs where fqns is a string we iterate | ||
| # over each character instead of iterating each full string in a list | ||
| assert isinstance( | ||
| config.fqns, list | ||
| ), "MXFP8Converter.Config.fqns must be a Python list of strings" | ||
|
|
||
| # Warn user if torch.compile is not enabled | ||
| if not model_compile_enabled: | ||
| logger.warning( | ||
| "torch.compile enablement is required for highest performance of MXFP8 dynamic quantization." | ||
| ) | ||
|
|
||
| logger.info("MXFP8 MoE training enabled") | ||
|
|
||
| # If EP is enabled, TorchTitan handles the token group padding for MXFP8 grouped GEMM | ||
| # as part of the EP implementation. | ||
| # Otherwise, if EP is not enabled, we need TorchAO to pad the token groups. | ||
| self.pad_token_groups_for_grouped_mm = not parallel_dims.ep_enabled | ||
| logger.warning( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this a warning? It sounds like a comment to me, especially since this warning would still be emitted even when HybridEP is used. |
||
| "For applying MXFP8 to MoE grouped GEMMs, use HybridEP with MXFP8 for best performance. " | ||
| "This fuses token group padding for MXFP8 grouped GEMM into the all-to-all dispatch, " | ||
| "substantially improving performance." | ||
| ) | ||
|
|
||
| self.config = config | ||
| self.enabled = True | ||
| logger.info("MXFP8 MoE training enabled") | ||
|
|
||
| def convert(self, model: nn.Module): | ||
| """ | ||
| Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor. | ||
| This will use low precision grouped GEMMs with dynamic quantization using the specified MX dtype, | ||
| rather than the default high precision grouped GEMMs, for the target MoE FQNs. | ||
| Mutates the model inplace replacing instances of nn.Parameter with MXFP8TrainingWeightWrapperTensor | ||
| for the target FQNs. This will dispatch linear and grouped_mm ops to the appropriate autograd | ||
| functions that implement dynamic MXFP8 quantization + MXFP8 linear/grouped_mm ops. | ||
| """ | ||
| if not self.enabled: | ||
| return | ||
|
|
@@ -119,7 +128,7 @@ def module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: | |
| recipe = MXFP8TrainingRecipe(self.config.recipe_name) | ||
| mxfp8_op_config = MXFP8TrainingOpConfig.from_recipe(recipe) | ||
| mxfp8_op_config.pad_token_groups_for_grouped_mm = ( | ||
| self.config.pad_token_groups_for_grouped_mm | ||
| self.pad_token_groups_for_grouped_mm | ||
| ) | ||
|
|
||
| quantize_(model, config=mxfp8_op_config, filter_fn=module_filter_fn) | ||
|
|
@@ -138,3 +147,13 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): | |
| MXFP8 training doesn't require any post-optimizer hooks at the moment | ||
| """ | ||
| return | ||
|
|
||
|
|
||
def find_mxfp8_config(
    converters: list,
) -> MXFP8Converter.Config | None:
    """Return the first MXFP8Converter.Config in a list of converter configs.

    Args:
        converters: Converter config objects to scan, in order.

    Returns:
        The first element that is an instance of MXFP8Converter.Config,
        or None when no element matches.
    """
    for candidate in converters:
        if isinstance(candidate, MXFP8Converter.Config):
            return candidate
    return None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,35 @@ | |
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from enum import auto, Enum | ||
|
|
||
| import torch.nn as nn | ||
| from torchtitan.components.quantization.float8 import find_float8_grouped_mm_config | ||
| from torchtitan.components.quantization.mx import find_mxfp8_config | ||
| from torchtitan.protocols import ModelConverter | ||
|
|
||
|
|
||
| class QuantizationType(Enum): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You already have https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/mx.py#L35; can we just use the strings? |
||
| FLOAT8 = auto() | ||
| MXFP8 = auto() | ||
|
|
||
|
|
||
def routed_experts_in_fqns(fqns: list[str]) -> bool:
    """Determine whether quantization is targeting routed experts (grouped GEMMs).

    An FQN counts as a routed-expert target when it contains the substring
    "experts" but does not contain "shared_experts".

    Args:
        fqns: Fully qualified names selected for quantization.

    Returns:
        True if at least one FQN targets routed experts, False otherwise
        (including when ``fqns`` is empty).
    """
    # "shared_experts" itself contains "experts", so it must be excluded
    # explicitly to avoid treating shared experts as routed experts.
    return any("experts" in fqn and "shared_experts" not in fqn for fqn in fqns)
|
|
||
|
|
||
| def get_grouped_mm_quantization_type( | ||
| model_converters: list[ModelConverter], | ||
| ) -> QuantizationType | None: | ||
| if find_float8_grouped_mm_config(model_converters): | ||
| return QuantizationType.FLOAT8 | ||
| elif config := find_mxfp8_config(model_converters): | ||
| if routed_experts_in_fqns(config.fqns): | ||
|
Comment on lines
+31
to
+34
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no need to split this into multiple small functions that are not used elsewhere -- we can put everything in a single util function for now. |
||
| return QuantizationType.MXFP8 | ||
|
|
||
|
|
||
| def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,13 +29,11 @@ | |
| ) | ||
| from torch.distributed import ProcessGroup | ||
|
|
||
| from torchtitan.models.common.moe.utils import ( | ||
| get_mxfp8_pad_multiple, | ||
| maybe_align_num_tokens_for_mxfp8, | ||
| ) | ||
|
|
||
| from torchtitan.components.quantization import MXFP8_GROUP_ALIGNMENT_SIZE | ||
| from torchtitan.components.quantization.utils import QuantizationType | ||
|
|
||
| _buffer: Any = None # Global buffer instance | ||
| _quantization_type: QuantizationType | None = None # Current quantization type | ||
|
|
||
|
|
||
| class DispatchHandle(OpaqueBase): | ||
|
|
@@ -308,8 +306,6 @@ def _combine_backward(ctx, grad_combined): | |
|
|
||
| # Must pass pad_multiple so backward gradients entering ScaledGroupedMM | ||
| # (torchao MXFP8) also have rows aligned to 32. | ||
| from torchtitan.models.common.moe.utils import get_mxfp8_pad_multiple | ||
|
|
||
| pad_multiple = get_mxfp8_pad_multiple() | ||
|
|
||
| grad_x, _, _, _, _ = _buffer.dispatch_with_permute( | ||
|
|
@@ -404,6 +400,7 @@ def dispatch_tokens( | |
| group: ProcessGroup, | ||
| score_before_experts: bool = True, | ||
| non_blocking_expert_capacity_factor: float | None = None, | ||
| quantization_type: QuantizationType | None = None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The hybridep module doesn't need to know the quantization_type. All it needs to know is the pad multiple size. |
||
| ) -> tuple[torch.Tensor, torch.Tensor, DispatchState]: | ||
| """Dispatch tokens to experts via HybridEP all-to-all. | ||
|
|
||
|
|
@@ -419,10 +416,15 @@ def dispatch_tokens( | |
| float in (0, 1] = non-blocking mode; pre-sizes the permute output | ||
| tensor as num_tokens × ep_size × min(num_local_experts, top_k) × cf, | ||
| aligned for MXFP8. | ||
| quantization_type: Quantization type (FLOAT8, MXFP8, or None) for | ||
| determining padding requirements. | ||
|
|
||
| Returns: | ||
| (permuted_hidden, tokens_per_expert, state) | ||
| """ | ||
| global _quantization_type | ||
| _quantization_type = quantization_type | ||
|
|
||
| non_blocking = non_blocking_expert_capacity_factor is not None | ||
|
|
||
| selected_experts_indices = selected_experts_indices.contiguous() | ||
|
|
@@ -474,6 +476,42 @@ def combine_tokens(hidden_states: torch.Tensor, state: DispatchState) -> torch.T | |
| return torch.ops.hybridep.combine(hidden_states, state.handle, state.num_tokens) | ||
|
|
||
|
|
||
def get_mxfp8_pad_multiple() -> int | None:
    """Return the pad multiple required by MXFP8 grouped GEMMs, if active.

    Reads the module-level ``_quantization_type``. When it equals
    ``QuantizationType.MXFP8``, dispatch kernels are expected to pad
    per-expert token groups to ``MXFP8_GROUP_ALIGNMENT_SIZE`` so the
    quantization kernel's row-count requirement is satisfied.

    Returns:
        ``MXFP8_GROUP_ALIGNMENT_SIZE`` when MXFP8 is the current quantization
        type, None otherwise.
    """
    if _quantization_type == QuantizationType.MXFP8:
        return MXFP8_GROUP_ALIGNMENT_SIZE
    return None
|
|
||
|
|
||
def maybe_align_num_tokens_for_mxfp8(num_tokens: int) -> int:
    """Round a token count up to the MXFP8 group alignment, when active.

    Args:
        num_tokens: The number of tokens to potentially align.

    Returns:
        ``num_tokens`` rounded up to the next multiple of
        ``MXFP8_GROUP_ALIGNMENT_SIZE`` when the module-level
        ``_quantization_type`` is MXFP8; ``num_tokens`` unchanged otherwise.
    """
    if _quantization_type == QuantizationType.MXFP8:
        alignment = MXFP8_GROUP_ALIGNMENT_SIZE
        # Ceiling division: number of alignment-sized groups needed.
        num_groups = (num_tokens + alignment - 1) // alignment
        return num_groups * alignment
    return num_tokens
|
|
||
|
|
||
| __all__ = [ | ||
| "dispatch_tokens", | ||
| "combine_tokens", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this a dict mapping quantization type to alignment size.