-
Notifications
You must be signed in to change notification settings - Fork 457
Fix KV cache calibration for attention modules not named self_attn #2477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
brian-dellabetta
merged 8 commits into
vllm-project:main
from
changjonathanc:fix/kv-cache-calibration-attention-modules
Mar 27, 2026
+128
−1
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
a055cc6
Fix KV cache calibration for attention modules not named self_attn
changjonathanc 1921df7
simplify: use status check instead of id set, drop redundant hasattr,…
changjonathanc 8a27ff3
Update src/llmcompressor/modifiers/quantization/quantization/mixin.py
changjonathanc 95474f1
remove fallback blocks now that regex covers attention module paths
changjonathanc 3262c7b
Apply suggestion from @brian-dellabetta
brian-dellabetta 1e137c0
use KV_CACHE_TARGETS
brian-dellabetta d7d29b2
Merge branch 'main' into fix/kv-cache-calibration-attention-modules
brian-dellabetta ed82c0c
Merge branch 'main' into fix/kv-cache-calibration-attention-modules
brian-dellabetta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
126 changes: 126 additions & 0 deletions
126
tests/llmcompressor/modifiers/quantization/test_kv_cache_calibration.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| """ | ||
| Test that start_calibration initializes KV cache observers on attention modules | ||
| regardless of their name in the module tree. | ||
|
|
||
| compressed_tensors' _apply_kv_cache_scheme discovers attention modules via | ||
| is_attention_module() (name-agnostic), but QuantizationMixin.start_calibration | ||
| previously relied on resolved_targets which includes "re:.*self_attn$". This | ||
| regex fails for models that name their attention differently (e.g. "attention", | ||
| "self_attention"). The fix adds an is_attention_module() fallback pass. | ||
| """ | ||
|
|
||
| import torch.nn as nn | ||
| from compressed_tensors.quantization import ( | ||
| QuantizationArgs, | ||
| QuantizationStatus, | ||
| apply_quantization_config, | ||
| is_attention_module, | ||
| ) | ||
| from transformers import PretrainedConfig | ||
|
|
||
| from llmcompressor.modifiers.quantization.quantization import QuantizationModifier | ||
|
|
||
|
|
||
| class _StubAttention(nn.Module): | ||
| """Minimal attention module recognized by is_attention_module(). | ||
|
|
||
| is_attention_module checks: "attention" in class name (lowercase) and | ||
| hasattr(k_proj) or hasattr(v_proj). | ||
| """ | ||
|
|
||
| def __init__(self, dim: int = 16): | ||
| super().__init__() | ||
| self.q_proj = nn.Linear(dim, dim, bias=False) | ||
| self.k_proj = nn.Linear(dim, dim, bias=False) | ||
| self.v_proj = nn.Linear(dim, dim, bias=False) | ||
| self.o_proj = nn.Linear(dim, dim, bias=False) | ||
|
|
||
| def forward(self, x): | ||
| return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x)) | ||
|
|
||
|
|
||
| class _StubBlock(nn.Module): | ||
| def __init__(self, dim: int = 16): | ||
| super().__init__() | ||
| # Named "attention" instead of "self_attn" — does NOT match | ||
| # the "re:.*self_attn$" regex in resolved_targets. | ||
| self.attention = _StubAttention(dim) | ||
| self.mlp = nn.Linear(dim, dim) | ||
|
|
||
| def forward(self, x): | ||
| return self.mlp(self.attention(x)) | ||
|
|
||
|
|
||
| class _StubModel(nn.Module): | ||
| def __init__(self, dim: int = 16, num_layers: int = 2): | ||
| super().__init__() | ||
| self.config = PretrainedConfig( | ||
| num_attention_heads=1, | ||
| num_key_value_heads=1, | ||
| hidden_size=dim, | ||
| ) | ||
| self.layers = nn.ModuleList([_StubBlock(dim) for _ in range(num_layers)]) | ||
|
|
||
| def forward(self, x): | ||
| for layer in self.layers: | ||
| x = layer(x) | ||
| return x | ||
|
|
||
|
|
||
| def test_attention_module_not_named_self_attn_gets_calibrated(): | ||
| """Attention modules named 'attention' (not 'self_attn') must still | ||
| get observers and hooks initialized when kv_cache_scheme is set.""" | ||
| model = _StubModel(dim=16) | ||
| modifier = QuantizationModifier( | ||
| targets=["Linear"], | ||
| kv_cache_scheme=QuantizationArgs( | ||
| num_bits=8, | ||
| type="float", | ||
| strategy="tensor", | ||
| dynamic=False, | ||
| symmetric=True, | ||
| ), | ||
| ) | ||
|
|
||
| # Verify our stub is recognized as attention | ||
| attn_modules = [ | ||
| (name, m) for name, m in model.named_modules() if is_attention_module(m) | ||
| ] | ||
| assert len(attn_modules) == 2, "Expected 2 attention modules in _StubModel" | ||
|
|
||
| # Verify none of them match the self_attn regex | ||
| for name, _ in attn_modules: | ||
| assert "self_attn" not in name, f"{name} unexpectedly matches self_attn" | ||
|
|
||
| # Apply quantization config (this uses is_attention_module and WILL set scheme) | ||
| apply_quantization_config(model, modifier.resolved_config) | ||
|
|
||
| # Verify schemes were applied to attention modules by _apply_kv_cache_scheme | ||
| for _, m in attn_modules: | ||
| assert hasattr(m, "quantization_scheme"), ( | ||
| "apply_quantization_config should set " | ||
| "quantization_scheme on attention modules" | ||
| ) | ||
|
|
||
| # Now run start_calibration — this is what we're testing | ||
| modifier.start_calibration(model) | ||
|
|
||
| # Verify attention modules got calibration status | ||
| for name, m in attn_modules: | ||
| assert m.quantization_status == QuantizationStatus.CALIBRATION, ( | ||
| f"Attention module '{name}' was not calibrated — " | ||
| f"start_calibration missed it (status={m.quantization_status})" | ||
| ) | ||
|
|
||
| # Verify hooks were registered for KV cache calibration | ||
| assert ( | ||
| len(modifier._calibration_hooks) > 0 | ||
| ), "Expected calibration hooks to be registered" | ||
|
|
||
| # Clean up | ||
| modifier.end_calibration(model) | ||
|
|
||
| for name, m in attn_modules: | ||
| assert ( | ||
| m.quantization_status == QuantizationStatus.FROZEN | ||
| ), f"Attention module '{name}' was not frozen after end_calibration" | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.