Only apply grouped GEMM padding for MXFP8 and FP8 non-HybridEP cases #2620

danielvegamyhre wants to merge 4 commits into main
Conversation
Have not tested yet because of devgpu issues, but if you want to take a look feel free @tianyu-l
I finished testing @tianyu-l, this is ready for review.
```python
from torchtitan.tools.utils import _round_up
from .kernels import generate_permute_indices

TOKEN_GROUP_ALIGN_SIZE_M = 0
```
We should remove this; setting global variables is error-prone. We should move the logic into the parallelize functions for the various combinations.
Makes sense, I am working on this. The changes are straightforward for standard EP, but for Hybrid EP it seems like it will require either (1) refactoring the custom ops, DispatchState, etc. to pass around the quantization type used, or (2) just a module-level variable storing the quantization type, similar to _buffer. I think (2) is a less invasive change, wdyt?
```python
def maybe_align_num_tokens_for_mxfp8(num_tokens: int) -> int:
    """Round up token count only when MXFP8 group alignment is active."""
    if TOKEN_GROUP_ALIGN_SIZE_M != MXFP8_GROUP_ALIGNMENT_SIZE:
        return num_tokens
    return _round_up(num_tokens, MXFP8_GROUP_ALIGNMENT_SIZE)
```
Move this logic to hybridep.py, including _round_up (as an inline function), which is currently only used once in this repo.
```python
# FP8/MXFP8 require groups to be permuted to expert-major order AND padded to
# `alignment_size`.
# Otherwise, we only need to permute to expert-major order.
if self.token_group_alignment > 0:
```
IMO the proper way is to create e.g. FP8ExpertParallel and dispatch to it in the parallelize function, instead of adding if-else branches to the existing ExpertParallel. Also, the condition should be whether quantization is used, not a token_group_alignment size set from somewhere else.
That works, did a refactor.
```python
# Source: https://github.com/pytorch/torchtitan/pull/2255
def _generate_permute_indices(
```
Could you verify that before vs. after, we get bitwise-identical results under the same seed and determinism settings?
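A minimal sketch of that before/after check, using a pure-Python stand-in for the model step; in the real repo you would seed torch, enable deterministic algorithms, and compare tensors with `torch.equal`:

```python
import random


def run_step(permute_impl, seed: int) -> list[int]:
    # Stand-in for one training step: output is fully determined by the seed.
    rng = random.Random(seed)
    tokens_per_expert = [rng.randint(0, 63) for _ in range(8)]
    return permute_impl(tokens_per_expert)


def permute_before(counts: list[int]) -> list[int]:
    # Pre-refactor implementation.
    return sorted(counts)


def permute_after(counts: list[int]) -> list[int]:
    # Refactored implementation under test: must match bit for bit.
    return sorted(counts)


out_before = run_step(permute_before, seed=1234)
out_after = run_step(permute_after, seed=1234)
assert out_before == out_after
```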
```python
TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


def indices_padding_wrapper(func: Callable) -> Callable:
```
I don't think we need this function any more. Please remove it and simplify:
torchtitan/torchtitan/models/common/moe/moe.py, lines 120 to 131 in c0c0bf9:

```
@@ -45,10 +45,9 @@ def backward(ctx, grad_output):
def indices_padding_wrapper(func: Callable) -> Callable:
    num_tokens_per_expert_group,
    ep_degree,
    num_local_experts,
    FLOAT8_GROUP_ALIGNMENT_SIZE,
```
Why do you need two different classes? You could just init with a different quantization type, which can be used to determine the alignment size, e.g. based on a static dict.
```python
FLOAT8_GROUP_ALIGNMENT_SIZE = 16
MXFP8_GROUP_ALIGNMENT_SIZE = 32
```
Make this a dict from quantization type to alignment size.
```python
if find_float8_grouped_mm_config(model_converters):
    return QuantizationType.FLOAT8
elif config := find_mxfp8_config(model_converters):
    if routed_experts_in_fqns(config.fqns):
```
No need to modularize into multiple small functions that are not used elsewhere; we can put everything in a single util function for now.
```python
from torchtitan.protocols import ModelConverter


class QuantizationType(Enum):
```
You already have https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/mx.py#L35, can we just use the strings?
```python
# as part of the EP implementation.
# Otherwise, if EP is not enabled, we need TorchAO to pad the token groups.
self.pad_token_groups_for_grouped_mm = not parallel_dims.ep_enabled
logger.warning(
```
Why is this a warning? It sounds like a comment to me, especially since when hybrid EP is used this warning would still be emitted.
```python
group: ProcessGroup,
score_before_experts: bool = True,
non_blocking_expert_capacity_factor: float | None = None,
quantization_type: QuantizationType | None = None,
```
The hybridep module doesn't need to know the quantization_type. All it needs to know is the pad multiple size.
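The narrower interface the comment suggests could look like this sketch, where only the pad multiple crosses the module boundary and the caller derives it from the quantization type. The class and method names are hypothetical:

```python
class HybridEPDispatcher:
    # The hybrid EP module only receives a pad multiple; it never sees the
    # quantization type, which stays a concern of the parallelize function.
    def __init__(self, pad_multiple: int = 1):
        assert pad_multiple >= 1
        self.pad_multiple = pad_multiple

    def padded_group_size(self, num_tokens: int) -> int:
        # Round each token group up to the configured multiple.
        m = self.pad_multiple
        return ((num_tokens + m - 1) // m) * m
```

With `pad_multiple=1` (the default) the padding is a no-op, so the unquantized path needs no special casing.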
```python
)


class Float8ExpertParallel(BaseExpertParallel):
```
Can you inherit from ExpertParallel instead of BaseExpertParallel? That should save a lot of code.
Context

Summary

There are 7 cases to handle:

- Apply permute_and_pad() if token_group_alignment_size > 0, in the ExpertParallel implementation

Misc changes

- kernels.py
- Remove the pad_token_groups_for_grouped_mm option from MXFP8ConverterConfig, since we can set it correctly automatically

Tests

FP8 tests were done with fp8 grouped mm only, not fp8 linear. Using both, I get this weird tyro error?