


@krishnateja95 krishnateja95 commented Nov 10, 2025

Pull Request Description

Purpose

This PR adds quantization configuration remapping support for the GraniteMoeHybrid model to handle the naming difference between Hugging Face model checkpoints and vLLM's internal model structure.

Problem: The GraniteMoeHybrid model uses mamba naming in HF checkpoints but mixer naming in vLLM's implementation (inherited from MambaMixer2). This mismatch causes quantization configurations to fail when targeting specific layers, as the ignore lists and target module names don't match the actual vLLM model structure.

Solution: Implements a maybe_update_quant_config() method that remaps quantization config module names from HF format (mamba) to vLLM format (mixer) using regex patterns. This ensures quantization configurations work correctly with the model's actual layer names.

Changes made:

  • Added maybe_update_quant_config() method to GraniteMoeHybridForCausalLM class
  • Added missing import re statement at the top of the file
  • Method updates both ignore lists and config_groups.targets lists in the quantization config
  • Remapping rules handle patterns like model.layers.{N}.mamba.in_proj → model.layers.{N}.mixer.in_proj
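
For illustration, here is a minimal standalone sketch of the remapping described above (the ignore list and layer names below are hypothetical examples, not taken from a real checkpoint):

import re

# Hypothetical quantization-config entries using HF checkpoint naming.
ignore = ["model.layers.0.mamba.in_proj", "model.layers.1.mamba.out_proj", "lm_head"]

# mamba -> mixer remapping rules, mirroring the patterns listed above.
rules = [
    (r"model\.layers\.(\d+)\.mamba\.in_proj", r"model.layers.\1.mixer.in_proj"),
    (r"model\.layers\.(\d+)\.mamba\.out_proj", r"model.layers.\1.mixer.out_proj"),
]

remapped = []
for name in ignore:
    for pattern, repl in rules:
        if re.fullmatch(pattern, name):
            name = re.sub(pattern, repl, name)
            break
    remapped.append(name)

print(remapped)
# ['model.layers.0.mixer.in_proj', 'model.layers.1.mixer.out_proj', 'lm_head']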

Test Plan

1. Integration Test with FP8 Quantization

# Test with FP8 quantization targeting specific layers
python -m vllm.entrypoints.openai.api_server \
    --model RedHatAI/granite-4.0-h-small-FP8-block \
    --max-model-len 2048

Test Result

Before the fix:

  • Quantization configs with model.layers.{N}.mamba.* patterns would not match any layers
  • Layers would be quantized/ignored incorrectly due to name mismatch
  • Potential errors or warnings during quantization initialization

After the fix:

  • Quantization config module names are automatically remapped from mamba to mixer
  • Both ignore lists and config_groups.targets are updated correctly
  • Quantization applies to the correct layers as intended
  • Model loads and runs successfully with quantization enabled

Additional Notes

  • This fix is specific to GraniteMoeHybrid model and follows similar patterns used in other vLLM model implementations
  • The remapping is applied during model initialization before quantization is configured
  • No changes to model behavior or performance; this only fixes configuration name matching
  • Follows vLLM's existing patterns for handling HF/vLLM naming discrepancies

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a quantization bug for the GraniteMoeHybrid model by remapping module names. The approach of using regex to update the quantization configuration is sound.

I've identified a critical bug in the implementation related to incorrect indentation, which could lead to crashes or silent errors in quantization. I've also suggested a refactoring that fixes this bug and improves the code's maintainability and efficiency by removing duplication and pre-compiling regex patterns. Please see the detailed comment.

Comment on lines 663 to 698
            remapping_rules = [
                # Granite model: mamba -> mixer remapping
                (
                    r"model\.layers\.(\d+)\.mamba\.in_proj",
                    r"model.layers.\1.mixer.in_proj",
                ),
                (
                    r"model\.layers\.(\d+)\.mamba\.out_proj",
                    r"model.layers.\1.mixer.out_proj",
                ),
            ]
            # Update ignore list
            if hasattr(quant_config, "ignore"):
                updated_ignore = []
                for name in quant_config.ignore:
                    updated_name = name
                    for pattern, repl in remapping_rules:
                        if re.fullmatch(pattern, name):
                            updated_name = re.sub(pattern, repl, name)
                    updated_ignore.append(updated_name)
                quant_config.ignore = updated_ignore
            # Update target list
            if hasattr(quant_config, "config_groups"):
                config_groups = quant_config.config_groups
                for group_name in config_groups:
                    if "targets" in config_groups[group_name]:
                        targets = []
                        for name in config_groups[group_name]["targets"]:
                            updated_name = name
                            for pattern, repl in remapping_rules:
                                if re.fullmatch(pattern, name):
                                    updated_name = re.sub(pattern, repl, name)
                            targets.append(updated_name)
                config_groups[group_name]["targets"] = targets
                quant_config.config_groups = config_groups
            return quant_config

critical

This section has a critical bug and can be refactored for better maintainability and efficiency.

Bug:
The line config_groups[group_name]["targets"] = targets is incorrectly indented. If a config group lacks a "targets" key, it will either be assigned the targets from a previous group or raise an UnboundLocalError. This can cause incorrect quantization or crashes.

Refactoring:
The logic for remapping module names is duplicated for the ignore list and config_groups targets. This can be extracted into a nested helper function. Additionally, pre-compiling the regex patterns improves performance.

The suggested change below fixes the bug and refactors the code for clarity and efficiency.

            remapping_rules = [
                # Granite model: mamba -> mixer remapping
                (
                    re.compile(r"model\.layers\.(\d+)\.mamba\.in_proj"),
                    r"model.layers.\1.mixer.in_proj",
                ),
                (
                    re.compile(r"model\.layers\.(\d+)\.mamba\.out_proj"),
                    r"model.layers.\1.mixer.out_proj",
                ),
            ]

            def _remap_names(names: list[str]) -> list[str]:
                """Helper to remap a list of names using regex rules."""
                remapped_names = []
                for name in names:
                    updated_name = name
                    for pattern, repl in remapping_rules:
                        if pattern.fullmatch(name):
                            updated_name = pattern.sub(repl, name)
                            # Assuming a name can only match one rule.
                            break
                    remapped_names.append(updated_name)
                return remapped_names

            # Update ignore list
            if hasattr(quant_config, "ignore"):
                quant_config.ignore = _remap_names(quant_config.ignore)

            # Update target list
            if hasattr(quant_config, "config_groups"):
                for group_config in quant_config.config_groups.values():
                    if "targets" in group_config:
                        group_config["targets"] = _remap_names(
                            group_config["targets"])
            return quant_config


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


                                if re.fullmatch(pattern, name):
                                    updated_name = re.sub(pattern, repl, name)
                            targets.append(updated_name)
                config_groups[group_name]["targets"] = targets


P1: Guard config_groups assignment behind target check

The config_groups[group_name]["targets"] = targets assignment is aligned with the for group_name loop rather than the preceding if "targets" in config_groups[group_name] block. When a quantization config contains a group that lacks a "targets" key (which is common for groups that only tweak global parameters), the code still executes this assignment and references the local variable targets that was never initialised, raising an UnboundLocalError during model construction. The assignment should live inside the if block so groups without targets are skipped.
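A minimal standalone sketch of the guarded placement described above (the config_groups contents are hypothetical; the point is the indentation of the final assignment):

import re

# Hypothetical config_groups, including a group with no "targets" key.
config_groups = {
    "group_0": {"targets": ["model.layers.0.mamba.in_proj"]},
    "group_1": {"weights": {"num_bits": 8}},  # no "targets": must be skipped
}

remapping_rules = [
    (re.compile(r"model\.layers\.(\d+)\.mamba\.in_proj"),
     r"model.layers.\1.mixer.in_proj"),
]

for group_name, group_config in config_groups.items():
    if "targets" in group_config:
        targets = []
        for name in group_config["targets"]:
            for pattern, repl in remapping_rules:
                if pattern.fullmatch(name):
                    name = pattern.sub(repl, name)
            targets.append(name)
        # The assignment lives inside the "targets" check, so group_1 is
        # skipped and no stale or undefined `targets` value is ever used.
        group_config["targets"] = targets

print(config_groups["group_0"]["targets"])   # ['model.layers.0.mixer.in_proj']
print("targets" in config_groups["group_1"])  # False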


@heheda12345
Collaborator

CC @mgoin @Isotr0py @yewentao256

@heheda12345
Collaborator

CC @tdoublep

Member

Isotr0py commented Nov 12, 2025

Seems the root issue is just that MambaMixer2's prefix was written wrong by mistake:

        self.mamba = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            n_groups=config.mamba_n_groups,
            num_heads=config.mamba_n_heads,
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

I think we just need to correct it instead of adding a hacky patch. 😅
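
For illustration, a minimal standalone sketch of why the prefix mismatch breaks quantization matching (the names are illustrative; the suggested correction is presumably changing the prefix argument above to f"{prefix}.mamba"):

prefix = "model.layers.0"

# Quantization configs and HF checkpoints refer to the layer as "...mamba...".
hf_name = f"{prefix}.mamba.in_proj"

# With prefix=f"{prefix}.mixer", vLLM builds the layer name as "...mixer...",
# so ignore/targets entries written against the HF name never match.
buggy_vllm_name = f"{prefix}.mixer.in_proj"

# With the prefix corrected to f"{prefix}.mamba", the names line up again.
fixed_vllm_name = f"{prefix}.mamba.in_proj"

print(hf_name == buggy_vllm_name)  # False
print(hf_name == fixed_vllm_name)  # True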

@krishnateja95
Author

Hi @mgoin @Isotr0py @heheda12345, I updated the PR. I removed the maybe_update_quant_config() function and updated the load_weights function so that both the base model and the quantized model load correctly. The llm-compressor PR (for the quantized granite4 model) is here: vllm-project/llm-compressor#2001
