Conversation

@rakkit (Contributor) commented on Aug 22, 2025:

Currently, all MoE models rely on the same forward logic (_run_experts_for_loop or _run_experts_grouped_mm), which is hardcoded to use SwiGLU.

This PR allows expert_parallel to accept arguments, giving users the flexibility to define custom expert models. For example, a user could specify a different activation function and implement their own forward function while still reusing the upstream expert_parallel logic:

@expert_parallel
def custom_experts_grouped_mm(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
    x: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
    act_fn: Callable | nn.Module,
) -> torch.Tensor:
    ...
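As a rough illustration, here is a minimal sketch of what such a custom forward could look like. It assumes the decorator is importable from torchtitan.distributed.expert_parallel, that expert weights are laid out as (num_experts, dim, hidden_dim) for w1/w3 and (num_experts, hidden_dim, dim) for w2, and it uses a simple per-expert loop; the name custom_experts_for_loop, the weight layout, and the loop body are illustrative only and not part of this PR:

from typing import Callable

import torch
from torch import nn

# Assumed import path for the decorator; adjust to wherever expert_parallel
# lives in your torchtitan checkout.
from torchtitan.distributed.expert_parallel import expert_parallel


@expert_parallel
def custom_experts_for_loop(
    w1: torch.Tensor,                     # assumed layout: (num_experts, dim, hidden_dim)
    w2: torch.Tensor,                     # assumed layout: (num_experts, hidden_dim, dim)
    w3: torch.Tensor,                     # assumed layout: (num_experts, dim, hidden_dim)
    x: torch.Tensor,                      # (num_tokens, dim), tokens grouped by expert
    num_tokens_per_expert: torch.Tensor,  # (num_experts,)
    act_fn: Callable | nn.Module,         # e.g. torch.nn.functional.gelu or a custom nn.Module
) -> torch.Tensor:
    # Plain per-expert loop: split the grouped tokens and run each expert's
    # gated MLP with the user-supplied activation instead of the hardcoded SiLU.
    # The decorator is expected to handle token permutation / expert parallelism
    # around this dense math.
    splits = torch.split(x, num_tokens_per_expert.tolist(), dim=0)
    outputs = []
    for i, xi in enumerate(splits):
        hidden = act_fn(xi @ w1[i]) * (xi @ w3[i])
        outputs.append(hidden @ w2[i])
    return torch.cat(outputs, dim=0)

Such a function could then be wired into a custom experts module in place of _run_experts_for_loop / _run_experts_grouped_mm, while the token permutation and expert-parallel handling continue to come from the upstream expert_parallel decorator.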

The meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Aug 22, 2025.
@tianyu-l (Contributor) left a comment:


I'm doing a relatively major refactor in #1569.
Would appreciate it if you could check whether the new indices_permutation_wrapper is still OK for you to extend.

@rakkit (Contributor, Author) commented on Aug 22, 2025:

Thanks @tianyu-l! Yes, conceptually I think the new indices_permutation_wrapper also works for this extension.
