Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion vllm/model_executor/models/granitemoehybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
make_layers,
maybe_prefix,
)

import re

class GraniteMoeHybridMambaDecoderLayer(nn.Module):
def __init__(
Expand Down Expand Up @@ -652,11 +652,62 @@ def get_mamba_state_shape_from_config(
conv_kernel=hf_config.mamba_d_conv,
)

def maybe_update_quant_config(
self, quant_config: QuantizationConfig
) -> QuantizationConfig:
"""
Update quant config so that ignored module and target module names
match the vLLM model names.
Granite model specific: mamba -> mixer remapping.
"""
remapping_rules = [
# Granite model: mamba -> mixer remapping
(
r"model\.layers\.(\d+)\.mamba\.in_proj",
r"model.layers.\1.mixer.in_proj",
),
(
r"model\.layers\.(\d+)\.mamba\.out_proj",
r"model.layers.\1.mixer.out_proj",
),
]
# Update ignore list
if hasattr(quant_config, "ignore"):
updated_ignore = []
for name in quant_config.ignore:
updated_name = name
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
updated_name = re.sub(pattern, repl, name)
updated_ignore.append(updated_name)
quant_config.ignore = updated_ignore
# Update target list
if hasattr(quant_config, "config_groups"):
config_groups = quant_config.config_groups
for group_name in config_groups:
if "targets" in config_groups[group_name]:
targets = []
for name in config_groups[group_name]["targets"]:
updated_name = name
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
updated_name = re.sub(pattern, repl, name)
targets.append(updated_name)
config_groups[group_name]["targets"] = targets

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard config_groups assignment behind target check

The config_groups[group_name]["targets"] = targets assignment is aligned with the for group_name loop rather than the preceding if "targets" in config_groups[group_name] block. When a quantization config contains a group that lacks a "targets" key (which is common for groups that only tweak global parameters), the code still executes this assignment and references the local variable targets that was never initialised, raising an UnboundLocalError during model construction. The assignment should live inside the if block so groups without targets are skipped.

Useful? React with 👍 / 👎.

quant_config.config_groups = config_groups
return quant_config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This section has a critical bug and can be refactored for better maintainability and efficiency.

Bug:
The line config_groups[group_name]["targets"] = targets is incorrectly indented. If a config group lacks a "targets" key, it will either be assigned the targets from a previous group or raise an UnboundLocalError. This can cause incorrect quantization or crashes.

Refactoring:
The logic for remapping module names is duplicated for the ignore list and config_groups targets. This can be extracted into a nested helper function. Additionally, pre-compiling the regex patterns improves performance.

The suggested change below fixes the bug and refactors the code for clarity and efficiency.

            remapping_rules = [
                # Granite model: mamba -> mixer remapping
                (
                    re.compile(r"model\.layers\.(\d+)\.mamba\.in_proj"),
                    r"model.layers.\1.mixer.in_proj",
                ),
                (
                    re.compile(r"model\.layers\.(\d+)\.mamba\.out_proj"),
                    r"model.layers.\1.mixer.out_proj",
                ),
            ]

            def _remap_names(names: list[str]) -> list[str]:
                """Helper to remap a list of names using regex rules."""
                remapped_names = []
                for name in names:
                    updated_name = name
                    for pattern, repl in remapping_rules:
                        if pattern.fullmatch(name):
                            updated_name = pattern.sub(repl, name)
                            # Assuming a name can only match one rule.
                            break
                    remapped_names.append(updated_name)
                return remapped_names

            # Update ignore list
            if hasattr(quant_config, "ignore"):
                quant_config.ignore = _remap_names(quant_config.ignore)

            # Update target list
            if hasattr(quant_config, "config_groups"):
                for group_config in quant_config.config_groups.values():
                    if "targets" in group_config:
                        group_config["targets"] = _remap_names(
                            group_config["targets"])
            return quant_config


def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config

if hasattr(vllm_config, "quant_config"):
vllm_config.quant_config = self.maybe_update_quant_config(
vllm_config.quant_config
)

self.model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
Expand Down
Loading