Skip to content

Commit 464d000

Browse files
Fix KV cache calibration for attention modules not named self_attn (#2477)
SUMMARY: `_apply_kv_cache_scheme` (in compressed-tensors) discovers attention modules via `is_attention_module()`, which is name-agnostic. However, `start_calibration` only iterates modules matching `resolved_targets`, which includes `"re:.*self_attn$"` for KV cache. This regex misses attention modules with different names (e.g. `"attention"`, `"self_attention"`), leaving their observers uninitialized and KV cache scales as garbage values. Add a fallback pass in `start_calibration` and `end_calibration` that uses `is_attention_module()` to catch any attention modules missed by the regex. Gated by `kv_cache_scheme is not None` so there is zero cost when KV cache quantization is not used. This addresses the existing TODO: "decouple reliance on this regex for matching attention". TEST PLAN: - Added unit test with a stub model whose attention modules are named `attention` (not `self_attn`). Verifies observers are initialized, hooks are registered, and modules are frozen correctly. - All 51 existing quantization modifier tests pass. --------- Signed-off-by: Jonathan Chang <changjonathanc@users.noreply.github.com> Signed-off-by: Jonathan Chang <31893406+changjonathanc@users.noreply.github.com> Signed-off-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com> Signed-off-by: Brian Dellabetta <bdellabe@redhat.com> Co-authored-by: Jonathan Chang <changjonathanc@users.noreply.github.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com> Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 8d9693d commit 464d000

File tree

2 files changed

+128
-1
lines changed

2 files changed

+128
-1
lines changed

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_preset_scheme,
2020
preset_name_to_scheme,
2121
)
22+
from compressed_tensors.quantization.utils import KV_CACHE_TARGETS
2223
from compressed_tensors.utils import match_named_modules, update_offload_parameter
2324
from pydantic import Field, PrivateAttr, field_validator
2425
from torch.utils.hooks import RemovableHandle
@@ -208,7 +209,7 @@ def resolved_targets(self) -> Set[str]:
208209

209210
if self.resolved_config.kv_cache_scheme is not None:
210211
# TODO: decouple reliance on this regex for matching attention
211-
targets.add("re:.*self_attn$")
212+
targets.update(KV_CACHE_TARGETS)
212213

213214
return targets
214215

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
Test that start_calibration initializes KV cache observers on attention modules
3+
regardless of their name in the module tree.
4+
5+
compressed_tensors' _apply_kv_cache_scheme discovers attention modules via
6+
is_attention_module() (name-agnostic), but QuantizationMixin.start_calibration
7+
previously relied on resolved_targets which includes "re:.*self_attn$". This
8+
regex fails for models that name their attention differently (e.g. "attention",
9+
"self_attention"). The fix adds an is_attention_module() fallback pass.
10+
"""
11+
12+
import torch.nn as nn
13+
from compressed_tensors.quantization import (
14+
QuantizationArgs,
15+
QuantizationStatus,
16+
apply_quantization_config,
17+
is_attention_module,
18+
)
19+
from transformers import PretrainedConfig
20+
21+
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier
22+
23+
24+
class _StubAttention(nn.Module):
25+
"""Minimal attention module recognized by is_attention_module().
26+
27+
is_attention_module checks: "attention" in class name (lowercase) and
28+
hasattr(k_proj) or hasattr(v_proj).
29+
"""
30+
31+
def __init__(self, dim: int = 16):
32+
super().__init__()
33+
self.q_proj = nn.Linear(dim, dim, bias=False)
34+
self.k_proj = nn.Linear(dim, dim, bias=False)
35+
self.v_proj = nn.Linear(dim, dim, bias=False)
36+
self.o_proj = nn.Linear(dim, dim, bias=False)
37+
38+
def forward(self, x):
39+
return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))
40+
41+
42+
class _StubBlock(nn.Module):
    """Transformer-style block whose attention submodule is deliberately
    named ``attention`` — NOT ``self_attn`` — so it does not match the
    ``re:.*self_attn$`` regex in resolved_targets."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.attention = _StubAttention(dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        attended = self.attention(x)
        return self.mlp(attended)
52+
53+
54+
class _StubModel(nn.Module):
    """Tiny stacked-block model exposing a HF-style ``config`` attribute.

    NOTE(review): the head/hidden-size fields are presumably read while
    resolving KV cache quantization parameters — confirm against
    compressed-tensors before changing them.
    """

    def __init__(self, dim: int = 16, num_layers: int = 2):
        super().__init__()
        self.config = PretrainedConfig(
            num_attention_heads=1,
            num_key_value_heads=1,
            hidden_size=dim,
        )
        blocks = [_StubBlock(dim) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
68+
69+
70+
def test_attention_module_not_named_self_attn_gets_calibrated():
    """Attention modules named 'attention' (not 'self_attn') must still
    get observers and hooks initialized when kv_cache_scheme is set."""
    stub = _StubModel(dim=16)
    kv_args = QuantizationArgs(
        num_bits=8,
        type="float",
        strategy="tensor",
        dynamic=False,
        symmetric=True,
    )
    modifier = QuantizationModifier(targets=["Linear"], kv_cache_scheme=kv_args)

    # Sanity check: the detector recognizes both stub attention modules.
    attention_entries = []
    for name, module in stub.named_modules():
        if is_attention_module(module):
            attention_entries.append((name, module))
    assert len(attention_entries) == 2, "Expected 2 attention modules in _StubModel"

    # Sanity check: none of their paths contain "self_attn", so the legacy
    # regex would have missed every one of them.
    for name, _ in attention_entries:
        assert "self_attn" not in name, f"{name} unexpectedly matches self_attn"

    # apply_quantization_config discovers attention via is_attention_module,
    # so it attaches a scheme regardless of module naming.
    apply_quantization_config(stub, modifier.resolved_config)

    for _, module in attention_entries:
        assert hasattr(module, "quantization_scheme"), (
            "apply_quantization_config should set "
            "quantization_scheme on attention modules"
        )

    # The behavior under test: start_calibration must reach these modules.
    modifier.start_calibration(stub)

    for name, module in attention_entries:
        assert module.quantization_status == QuantizationStatus.CALIBRATION, (
            f"Attention module '{name}' was not calibrated — "
            f"start_calibration missed it (status={module.quantization_status})"
        )

    # KV cache calibration registers forward hooks on the attention modules.
    assert (
        len(modifier._calibration_hooks) > 0
    ), "Expected calibration hooks to be registered"

    # Tear down and verify end_calibration also reaches the renamed modules.
    modifier.end_calibration(stub)

    for name, module in attention_entries:
        assert (
            module.quantization_status == QuantizationStatus.FROZEN
        ), f"Attention module '{name}' was not frozen after end_calibration"

0 commit comments

Comments
 (0)