Skip to content

Conversation

shubhra
Copy link
Collaborator

@shubhra shubhra commented Sep 17, 2025

SUMMARY:
Code to linearize and quantize the gpt-oss models

Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @shubhra, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new module designed to enable the linearization and quantization of GPT-OSS models, particularly addressing their Mixture of Experts (MoE) architecture. It provides a mechanism to convert the model's MoE layers into a format suitable for quantization, defines the necessary custom MLP components, and integrates a full workflow for applying FP8 dynamic quantization using a specified recipe and calibration dataset. The overall goal is to facilitate the compression of GPT-OSS models for improved efficiency.

Highlights

  • GPT-OSS MoE Linearization: Introduces convert_model_for_quantization_gptoss to transform GPT-OSS's fused-expert Mixture of Experts (MoE) layers into a sequential structure of individual GPTOSSMLP modules, making them compatible with quantization.
  • Custom MLP Implementation: Adds GPTOSSMLP to represent individual expert MLPs, including specific activation functions with clamp and sigmoid operations.
  • Sequential MoE Handling: Implements SequentialGPTOSSMoE to manage the individual GPTOSSMLP experts, copy weights from the original fused MoE, and integrate with the existing router for expert selection.
  • FP8 Dynamic Quantization: Demonstrates the application of an FP8_DYNAMIC quantization scheme using llmcompressor's QuantizationModifier, with specific layers like lm_head, self_attn, attn, attention, and router ignored.
  • Calibration Data Pipeline: Includes a complete pipeline for loading and preprocessing calibration data from the HuggingFaceH4/ultrachat_200k dataset for the quantization process.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a script to linearize and quantize GPT-OSS models by replacing the fused MoE implementation with a sequential one. The approach is sound, but the script has some issues that affect its reusability and performance. Specifically, it contains hardcoded file paths that should be parameterized. The main execution logic should also be placed within an if __name__ == "__main__": block. Additionally, there's an opportunity to optimize a loop in the MoE forward pass for better performance during calibration. I've included suggestions to address these points.

@kylesayrs kylesayrs mentioned this pull request Sep 17, 2025
shubhra and others added 4 commits September 17, 2025 11:36
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@nikhil-arm
Copy link

Thanks for your contribution.
May I know how will gpt_oss modelling know that it has to load linear MOE or interleaved moe? Will you add a marker to config?
https://github.com/vllm-project/vllm/blob/4f02b77de4e794a0d417ed98a26884208f75e043/vllm/model_executor/models/gpt_oss.py#L470

FYI,
I have some changes that creates correct mapping to linearized MOE for gptoss model in vLLM.
I am planning to align them with these changes for model load.

self.intermediate_size = intermediate_size
self.alpha = 1.702
self.limit = 7.0
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add dtype preservation here as well

self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype)

@shanjiaz
Copy link
Collaborator

shanjiaz commented Sep 18, 2025

Thanks so much for adding this!!

down_w = dwn[i] # [I, H]

mlp = self.experts[i]
mlp.gate_proj.weight.data.copy_(gate_w.T) # [I, H]
Copy link
Collaborator

@shanjiaz shanjiaz Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use update_offload_parameter here?

with align_module_device(experts):
    for expert_index, expert in enumerate(self.experts):
        update_offload_parameter(
            expert.gate_proj,
            "weight",
            experts.gate_up_proj[expert_index, ..., ::2].T,
        )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants