From 1728b913988349b06036f1572d1e4642ce2f9706 Mon Sep 17 00:00:00 2001 From: NJX-njx <3771829673@qq.com> Date: Wed, 4 Mar 2026 18:25:06 +0800 Subject: [PATCH] fix: generalize fused layer global scales for DeepSeek MLA attention Fixes #2360 The update_fused_layer_weight_global_scales function previously assumed all attention modules use standard q_proj/k_proj/v_proj projections. This caused an AttributeError for DeepSeek V2/V3 models that use Multi-head Latent Attention (MLA) with different projection names (q_a_proj, kv_a_proj_with_mqa, q_b_proj, kv_b_proj). Changes: - Introduce a configurable _ATTENTION_FUSED_GROUPS registry that lists all known attention projection groups (standard QKV, fused QKV, MLA compressed, MLA decompressed) - Refactor the attention branch to iterate over groups and match the first complete group, gracefully skipping unrecognized modules - Extract _fuse_global_scales helper to remove code duplication - Move _valid_tensor_group_quant to module level for reuse - Remove unused Linear import (use Module instead) - Add comprehensive test suite covering standard, fused, MLA, and edge case scenarios --- src/llmcompressor/modifiers/utils/helpers.py | 166 +++++++------- .../utils/test_fused_global_scales.py | 213 ++++++++++++++++++ 2 files changed, 302 insertions(+), 77 deletions(-) create mode 100644 tests/llmcompressor/modifiers/utils/test_fused_global_scales.py diff --git a/src/llmcompressor/modifiers/utils/helpers.py b/src/llmcompressor/modifiers/utils/helpers.py index cbba632d94..cddac59c75 100644 --- a/src/llmcompressor/modifiers/utils/helpers.py +++ b/src/llmcompressor/modifiers/utils/helpers.py @@ -7,104 +7,116 @@ strategies like NVFP4. """ +import logging + import torch from compressed_tensors.offload import align_modules, update_offload_parameter from compressed_tensors.quantization import QuantizationStrategy, is_attention_module -from torch.nn import Linear, Module +from torch.nn import Module __all__ = ["update_fused_layer_weight_global_scales"] +logger = logging.getLogger(__name__) -def update_fused_layer_weight_global_scales(submodule: torch.nn.Module): - """ - When running NVFP4 quantization, update the global scale - such that q,k,v layers are treated as one tensor with the same - global_scale and gate_proj/up_proj layers are treated as one tensor - with the same global scale. This is requirement currently being set - by vLLM and may be removed in the future OR potentially make it - an optional step. +# Fused attention projection groups. +# Each entry is a list of attribute names that should share the same global scale. +# The first matching group is used; order matters. +_ATTENTION_FUSED_GROUPS: list[list[str]] = [ + # Already-fused QKV (e.g. GPT-NeoX, Falcon) + ["qkv_proj"], + # Standard separate Q/K/V projections (Llama, Mistral, Qwen, etc.) + ["q_proj", "k_proj", "v_proj"], + # DeepSeek V2/V3 MLA: compressed Q + compressed KV + ["q_a_proj", "kv_a_proj_with_mqa"], + # DeepSeek V2/V3 MLA: decompressed Q + decompressed KV + ["q_b_proj", "kv_b_proj"], +] - :param model: model to quantize + +def _fuse_global_scales(layers: list[Module]): + """ + Given a list of Linear-like modules, set all of their + ``weight_global_scale`` parameters to the element-wise minimum. """ + with align_modules(layers): + global_scale = torch.min( + torch.cat([layer.weight_global_scale.data for layer in layers]) + ).reshape([1]) - def _is_mlp_module(module: Module): - return "mlp" in module.__class__.__name__.lower() and ( - hasattr(module, "gate_proj") and hasattr(module, "up_proj") - ) + for layer in layers: + update_offload_parameter(layer, "weight_global_scale", global_scale) - def _valid_tensor_group_quant(layer_list: list[Linear]): - """ - Return True if all the linear layers in the layer_list are - TENSOR_GROUP quantized. - """ - for layer in layer_list: - scheme = getattr(layer, "quantization_scheme", None) - if scheme is None: - return False + del global_scale - weight_quant_args = scheme.weights - if weight_quant_args is None: - return False +def _valid_tensor_group_quant(layer_list: list[Module]) -> bool: + """ + Return True if all the modules in *layer_list* are + TENSOR_GROUP quantized (i.e. they carry an NVFP4-style global scale). + """ + for layer in layer_list: + scheme = getattr(layer, "quantization_scheme", None) + if scheme is None: + return False - if weight_quant_args.strategy != QuantizationStrategy.TENSOR_GROUP: - return False - return True + weight_quant_args = scheme.weights + if weight_quant_args is None: + return False - if is_attention_module(submodule): - # already fused/treated as one layer - if hasattr(submodule, "qkv_proj"): - return + if weight_quant_args.strategy != QuantizationStrategy.TENSOR_GROUP: + return False + return True - # not traditional attention (TODO: MLA) - if not ( - hasattr(submodule, "q_proj") - and hasattr(submodule, "k_proj") - and hasattr(submodule, "v_proj") - ): - return - if not _valid_tensor_group_quant( - [submodule.q_proj, submodule.k_proj, submodule.v_proj] - ): - return +def update_fused_layer_weight_global_scales(submodule: torch.nn.Module): + """ + When running NVFP4 quantization, update the global scale + such that fused projection layers are treated as one tensor with the same + global_scale. Specifically: - with align_modules([submodule.q_proj, submodule.v_proj, submodule.k_proj]): - global_scale = torch.min( - torch.cat( - ( - submodule.q_proj.weight_global_scale.data, - submodule.k_proj.weight_global_scale.data, - submodule.v_proj.weight_global_scale.data, - ) - ) - ).reshape([1]) + * **Attention**: q/k/v projections (or MLA-style compressed projections + like ``q_a_proj``/``kv_a_proj_with_mqa``) share one global scale. + * **MLP**: gate_proj and up_proj share one global scale. - update_offload_parameter(submodule.k_proj, "weight_global_scale", global_scale) - update_offload_parameter(submodule.q_proj, "weight_global_scale", global_scale) - update_offload_parameter(submodule.v_proj, "weight_global_scale", global_scale) + This is a requirement currently being set by vLLM and may be removed in + the future OR potentially made an optional step. - del global_scale + :param submodule: a single sub-module of the model (attention or MLP block) + """ + def _is_mlp_module(module: Module) -> bool: + return "mlp" in module.__class__.__name__.lower() and ( + hasattr(module, "gate_proj") and hasattr(module, "up_proj") + ) + + # --- Attention fused groups --- + if is_attention_module(submodule): + for group in _ATTENTION_FUSED_GROUPS: + layers = [ + getattr(submodule, name) + for name in group + if hasattr(submodule, name) + ] + # Only fuse when ALL names in the group are present + if len(layers) != len(group): + continue + + # Skip single-projection groups (already fused, e.g. qkv_proj) + if len(layers) <= 1: + return + + if not _valid_tensor_group_quant(layers): + return + + _fuse_global_scales(layers) + return # only the first matching group applies + + # --- MLP fused groups --- if _is_mlp_module(submodule): - if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]): - return + gate_proj = getattr(submodule, "gate_proj") + up_proj = getattr(submodule, "up_proj") - with align_modules([submodule.gate_proj, submodule.up_proj]): - global_scale = torch.min( - torch.cat( - ( - submodule.gate_proj.weight_global_scale.data, - submodule.up_proj.weight_global_scale.data, - ) - ) - ).reshape([1]) - - update_offload_parameter( - submodule.gate_proj, - "weight_global_scale", - global_scale, - ) - update_offload_parameter(submodule.up_proj, "weight_global_scale", global_scale) + if not _valid_tensor_group_quant([gate_proj, up_proj]): + return - del global_scale + _fuse_global_scales([gate_proj, up_proj]) diff --git a/tests/llmcompressor/modifiers/utils/test_fused_global_scales.py b/tests/llmcompressor/modifiers/utils/test_fused_global_scales.py new file mode 100644 index 0000000000..e2a443ed9b --- /dev/null +++ b/tests/llmcompressor/modifiers/utils/test_fused_global_scales.py @@ -0,0 +1,213 @@ +""" +Tests for update_fused_layer_weight_global_scales, including support for +MLA-style attention modules (e.g. DeepSeek V2/V3). + +Covers: +- Standard Q/K/V projections +- Already-fused QKV projection +- DeepSeek MLA projections (q_a_proj + kv_a_proj_with_mqa) +- Modules with no matching fused group (should be a no-op) +- MLP gate_proj/up_proj fusion +""" + +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn +from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationStrategy + +from llmcompressor.modifiers.utils.helpers import update_fused_layer_weight_global_scales + + +# --------------------------------------------------------------------------- +# Helpers to build mock modules +# --------------------------------------------------------------------------- + + +def _make_linear_with_global_scale( + in_features: int, + out_features: int, + global_scale_value: float, +) -> nn.Linear: + """Create a ``nn.Linear`` with a ``weight_global_scale`` parameter and a + ``quantization_scheme`` that declares TENSOR_GROUP strategy.""" + linear = nn.Linear(in_features, out_features, bias=False) + linear.weight_global_scale = nn.Parameter( + torch.tensor([global_scale_value], dtype=torch.float32) + ) + linear.quantization_scheme = QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + type="float", + strategy=QuantizationStrategy.TENSOR_GROUP, + group_size=16, + symmetric=True, + ), + ) + return linear + + +class _FakeStandardAttention(nn.Module): + """Mock attention module with standard q_proj, k_proj, v_proj.""" + + def __init__(self, dim: int = 64, scales=(1.0, 2.0, 3.0)): + super().__init__() + self.q_proj = _make_linear_with_global_scale(dim, dim, scales[0]) + self.k_proj = _make_linear_with_global_scale(dim, dim, scales[1]) + self.v_proj = _make_linear_with_global_scale(dim, dim, scales[2]) + + +class _FakeFusedQKVAttention(nn.Module): + """Mock attention module with already-fused qkv_proj.""" + + def __init__(self, dim: int = 64, scale: float = 5.0): + super().__init__() + self.qkv_proj = _make_linear_with_global_scale(dim, dim * 3, scale) + + +class _FakeMLAAttention(nn.Module): + """Mock DeepSeek V2/V3 MLA-style attention with q_a_proj + kv_a_proj_with_mqa.""" + + def __init__(self, dim: int = 64, scales=(4.0, 8.0)): + super().__init__() + self.q_a_proj = _make_linear_with_global_scale(dim, dim, scales[0]) + self.kv_a_proj_with_mqa = _make_linear_with_global_scale(dim, dim, scales[1]) + + +class _FakeMLAAttentionFull(nn.Module): + """MLA attention with both compressed and decompressed projections.""" + + def __init__(self, dim: int = 64): + super().__init__() + # compressed + self.q_a_proj = _make_linear_with_global_scale(dim, dim, 4.0) + self.kv_a_proj_with_mqa = _make_linear_with_global_scale(dim, dim, 8.0) + # decompressed + self.q_b_proj = _make_linear_with_global_scale(dim, dim, 2.0) + self.kv_b_proj = _make_linear_with_global_scale(dim, dim, 6.0) + + +class _FakeNoMatchAttention(nn.Module): + """Attention module with non-standard projection names that don't match any group.""" + + def __init__(self, dim: int = 64): + super().__init__() + self.custom_proj_a = _make_linear_with_global_scale(dim, dim, 1.0) + self.custom_proj_b = _make_linear_with_global_scale(dim, dim, 2.0) + + +class FakeMLP(nn.Module): + """MLP module with gate_proj and up_proj.""" + + def __init__(self, dim: int = 64, scales=(3.0, 7.0)): + super().__init__() + self.gate_proj = _make_linear_with_global_scale(dim, dim, scales[0]) + self.up_proj = _make_linear_with_global_scale(dim, dim, scales[1]) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestUpdateFusedLayerWeightGlobalScales: + """Test suite for update_fused_layer_weight_global_scales.""" + + @patch( + "llmcompressor.modifiers.utils.helpers.is_attention_module", + return_value=True, + ) + def test_standard_qkv_projection(self, mock_is_attn): + """Standard q/k/v projections should be fused to the minimum scale.""" + module = _FakeStandardAttention(scales=(1.0, 2.0, 3.0)) + update_fused_layer_weight_global_scales(module) + + expected_min = 1.0 + for proj in [module.q_proj, module.k_proj, module.v_proj]: + assert proj.weight_global_scale.item() == pytest.approx(expected_min) + + @patch( + "llmcompressor.modifiers.utils.helpers.is_attention_module", + return_value=True, + ) + def test_fused_qkv_is_noop(self, mock_is_attn): + """Already-fused qkv_proj should not be touched.""" + module = _FakeFusedQKVAttention(scale=5.0) + update_fused_layer_weight_global_scales(module) + + assert module.qkv_proj.weight_global_scale.item() == pytest.approx(5.0) + + @patch( + "llmcompressor.modifiers.utils.helpers.is_attention_module", + return_value=True, + ) + def test_deepseek_mla_compressed(self, mock_is_attn): + """DeepSeek MLA compressed projections should be fused.""" + module = _FakeMLAAttention(scales=(4.0, 8.0)) + update_fused_layer_weight_global_scales(module) + + expected_min = 4.0 + assert module.q_a_proj.weight_global_scale.item() == pytest.approx( + expected_min + ) + assert module.kv_a_proj_with_mqa.weight_global_scale.item() == pytest.approx( + expected_min + ) + + @patch( + "llmcompressor.modifiers.utils.helpers.is_attention_module", + return_value=True, + ) + def test_deepseek_mla_full_only_first_group_matches(self, mock_is_attn): + """When both compressed and decompressed MLA projections exist, + only the first matching group (compressed) is fused per call. + + In practice the module tree is traversed and the function is called + for each sub-module, so both groups will eventually be handled. + But per single call, only the first match should apply.""" + module = _FakeMLAAttentionFull() + update_fused_layer_weight_global_scales(module) + + # The first matching group is q_a_proj + kv_a_proj_with_mqa + expected_min_compressed = 4.0 + assert module.q_a_proj.weight_global_scale.item() == pytest.approx( + expected_min_compressed + ) + assert module.kv_a_proj_with_mqa.weight_global_scale.item() == pytest.approx( + expected_min_compressed + ) + + @patch( + "llmcompressor.modifiers.utils.helpers.is_attention_module", + return_value=True, + ) + def test_no_matching_attention_group_is_noop(self, mock_is_attn): + """Attention with unrecognized projection names should be a no-op.""" + module = _FakeNoMatchAttention() + update_fused_layer_weight_global_scales(module) + + assert module.custom_proj_a.weight_global_scale.item() == pytest.approx(1.0) + assert module.custom_proj_b.weight_global_scale.item() == pytest.approx(2.0) + + def test_mlp_gate_up_fusion(self): + """MLP gate_proj and up_proj should be fused to the minimum scale.""" + module = FakeMLP(scales=(3.0, 7.0)) + update_fused_layer_weight_global_scales(module) + + expected_min = 3.0 + assert module.gate_proj.weight_global_scale.item() == pytest.approx( + expected_min + ) + assert module.up_proj.weight_global_scale.item() == pytest.approx(expected_min) + + @patch( + "llmcompressor.modifiers.utils.helpers.is_attention_module", + return_value=False, + ) + def test_non_attention_non_mlp_is_noop(self, mock_is_attn): + """A module that is neither attention nor MLP should be untouched.""" + module = nn.Linear(64, 64) + update_fused_layer_weight_global_scales(module) + # No crash, no changes — just verify it's a no-op