-
Notifications
You must be signed in to change notification settings - Fork 453
fix: generalize fused layer global scales for DeepSeek MLA attention #2437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,104 +7,116 @@ | |
| strategies like NVFP4. | ||
| """ | ||
|
|
||
| import logging | ||
|
|
||
| import torch | ||
| from compressed_tensors.offload import align_modules, update_offload_parameter | ||
| from compressed_tensors.quantization import QuantizationStrategy, is_attention_module | ||
| from torch.nn import Linear, Module | ||
| from torch.nn import Module | ||
|
|
||
| __all__ = ["update_fused_layer_weight_global_scales"] | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| def update_fused_layer_weight_global_scales(submodule: torch.nn.Module): | ||
| """ | ||
| When running NVFP4 quantization, update the global scale | ||
| such that q,k,v layers are treated as one tensor with the same | ||
| global_scale and gate_proj/up_proj layers are treated as one tensor | ||
| with the same global scale. This is requirement currently being set | ||
| by vLLM and may be removed in the future OR potentially make it | ||
| an optional step. | ||
| # Fused attention projection groups. | ||
| # Each entry is a list of attribute names that should share the same global scale. | ||
| # The first matching group is used; order matters. | ||
| _ATTENTION_FUSED_GROUPS: list[list[str]] = [ | ||
| # Already-fused QKV (e.g. GPT-NeoX, Falcon) | ||
| ["qkv_proj"], | ||
| # Standard separate Q/K/V projections (Llama, Mistral, Qwen, etc.) | ||
| ["q_proj", "k_proj", "v_proj"], | ||
| # DeepSeek V2/V3 MLA: compressed Q + compressed KV | ||
| ["q_a_proj", "kv_a_proj_with_mqa"], | ||
| # DeepSeek V2/V3 MLA: decompressed Q + decompressed KV | ||
| ["q_b_proj", "kv_b_proj"], | ||
| ] | ||
|
|
||
| :param model: model to quantize | ||
|
|
||
| def _fuse_global_scales(layers: list[Module]): | ||
| """ | ||
| Given a list of Linear-like modules, set all of their | ||
| ``weight_global_scale`` parameters to the element-wise minimum. | ||
| """ | ||
| with align_modules(layers): | ||
| global_scale = torch.min( | ||
| torch.cat([layer.weight_global_scale.data for layer in layers]) | ||
| ).reshape([1]) | ||
|
|
||
| def _is_mlp_module(module: Module): | ||
| return "mlp" in module.__class__.__name__.lower() and ( | ||
| hasattr(module, "gate_proj") and hasattr(module, "up_proj") | ||
| ) | ||
| for layer in layers: | ||
| update_offload_parameter(layer, "weight_global_scale", global_scale) | ||
|
|
||
| def _valid_tensor_group_quant(layer_list: list[Linear]): | ||
| """ | ||
| Return True if all the linear layers in the layer_list are | ||
| TENSOR_GROUP quantized. | ||
| """ | ||
| for layer in layer_list: | ||
| scheme = getattr(layer, "quantization_scheme", None) | ||
| if scheme is None: | ||
| return False | ||
| del global_scale | ||
|
|
||
| weight_quant_args = scheme.weights | ||
|
|
||
| if weight_quant_args is None: | ||
| return False | ||
| def _valid_tensor_group_quant(layer_list: list[Module]) -> bool: | ||
| """ | ||
| Return True if all the modules in *layer_list* are | ||
| TENSOR_GROUP quantized (i.e. they carry an NVFP4-style global scale). | ||
| """ | ||
| for layer in layer_list: | ||
| scheme = getattr(layer, "quantization_scheme", None) | ||
| if scheme is None: | ||
| return False | ||
|
|
||
| if weight_quant_args.strategy != QuantizationStrategy.TENSOR_GROUP: | ||
| return False | ||
| return True | ||
| weight_quant_args = scheme.weights | ||
| if weight_quant_args is None: | ||
| return False | ||
|
|
||
| if is_attention_module(submodule): | ||
| # already fused/treated as one layer | ||
| if hasattr(submodule, "qkv_proj"): | ||
| return | ||
| if weight_quant_args.strategy != QuantizationStrategy.TENSOR_GROUP: | ||
| return False | ||
| return True | ||
|
|
||
| # not traditional attention (TODO: MLA) | ||
| if not ( | ||
| hasattr(submodule, "q_proj") | ||
| and hasattr(submodule, "k_proj") | ||
| and hasattr(submodule, "v_proj") | ||
| ): | ||
| return | ||
|
|
||
| if not _valid_tensor_group_quant( | ||
| [submodule.q_proj, submodule.k_proj, submodule.v_proj] | ||
| ): | ||
| return | ||
| def update_fused_layer_weight_global_scales(submodule: torch.nn.Module): | ||
| """ | ||
| When running NVFP4 quantization, update the global scale | ||
| such that fused projection layers are treated as one tensor with the same | ||
| global_scale. Specifically: | ||
|
|
||
| with align_modules([submodule.q_proj, submodule.v_proj, submodule.k_proj]): | ||
| global_scale = torch.min( | ||
| torch.cat( | ||
| ( | ||
| submodule.q_proj.weight_global_scale.data, | ||
| submodule.k_proj.weight_global_scale.data, | ||
| submodule.v_proj.weight_global_scale.data, | ||
| ) | ||
| ) | ||
| ).reshape([1]) | ||
| * **Attention**: q/k/v projections (or MLA-style compressed projections | ||
| like ``q_a_proj``/``kv_a_proj_with_mqa``) share one global scale. | ||
| * **MLP**: gate_proj and up_proj share one global scale. | ||
|
|
||
| update_offload_parameter(submodule.k_proj, "weight_global_scale", global_scale) | ||
| update_offload_parameter(submodule.q_proj, "weight_global_scale", global_scale) | ||
| update_offload_parameter(submodule.v_proj, "weight_global_scale", global_scale) | ||
| This is a requirement currently being set by vLLM and may be removed in | ||
| the future OR potentially made an optional step. | ||
|
|
||
| del global_scale | ||
| :param submodule: a single sub-module of the model (attention or MLP block) | ||
| """ | ||
|
|
||
| def _is_mlp_module(module: Module) -> bool: | ||
| return "mlp" in module.__class__.__name__.lower() and ( | ||
| hasattr(module, "gate_proj") and hasattr(module, "up_proj") | ||
| ) | ||
|
|
||
| # --- Attention fused groups --- | ||
| if is_attention_module(submodule): | ||
| for group in _ATTENTION_FUSED_GROUPS: | ||
| layers = [ | ||
| getattr(submodule, name) | ||
| for name in group | ||
| if hasattr(submodule, name) | ||
| ] | ||
| # Only fuse when ALL names in the group are present | ||
| if len(layers) != len(group): | ||
| continue | ||
|
|
||
| # Skip single-projection groups (already fused, e.g. qkv_proj) | ||
| if len(layers) <= 1: | ||
| return | ||
|
|
||
| if not _valid_tensor_group_quant(layers): | ||
| return | ||
|
|
||
| _fuse_global_scales(layers) | ||
| return # only the first matching group applies | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When an MLA attention module contains both the compressed projection pair (`q_a_proj`/`kv_a_proj_with_mqa`) and the decompressed pair (`q_b_proj`/`kv_b_proj`), the early `return` after the first matching group means only the first pair gets fused global scales — the `q_b_proj`/`kv_b_proj` pair is never processed. [comment reconstructed from truncated text — verify against the original review] Useful? React with 👍 / 👎. |
||
|
|
||
|
Comment on lines
+93
to
+113
|
||
| # --- MLP fused groups --- | ||
| if _is_mlp_module(submodule): | ||
| if not _valid_tensor_group_quant([submodule.gate_proj, submodule.up_proj]): | ||
| return | ||
| gate_proj = getattr(submodule, "gate_proj") | ||
| up_proj = getattr(submodule, "up_proj") | ||
|
|
||
| with align_modules([submodule.gate_proj, submodule.up_proj]): | ||
| global_scale = torch.min( | ||
| torch.cat( | ||
| ( | ||
| submodule.gate_proj.weight_global_scale.data, | ||
| submodule.up_proj.weight_global_scale.data, | ||
| ) | ||
| ) | ||
| ).reshape([1]) | ||
|
|
||
| update_offload_parameter( | ||
| submodule.gate_proj, | ||
| "weight_global_scale", | ||
| global_scale, | ||
| ) | ||
| update_offload_parameter(submodule.up_proj, "weight_global_scale", global_scale) | ||
| if not _valid_tensor_group_quant([gate_proj, up_proj]): | ||
| return | ||
|
|
||
| del global_scale | ||
| _fuse_global_scales([gate_proj, up_proj]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_ATTENTION_FUSED_GROUPS` doesn't account for the DeepSeek variant where the compressed Q projection is named `q_proj` (the repo already notes some DeepSeek models use `q_proj` instead of `q_a_proj`). In that case an attention module with `q_proj` + `kv_a_proj_with_mqa` will match no group and scales won't be fused. Consider adding an additional group for `["q_proj", "kv_a_proj_with_mqa"]` (ordered before the standard `["q_proj", "k_proj", "v_proj"]` group, since that one won't fully match anyway) or otherwise handling this alias explicitly.