-
Notifications
You must be signed in to change notification settings - Fork 453
Fix KV cache calibration for attention modules not named self_attn #2477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e534e49
9e30a28
20930b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -202,7 +202,7 @@ def resolved_targets(self) -> Set[str]: | |
| targets.add(target) | ||
|
|
||
| if self.resolved_config.kv_cache_scheme is not None: | ||
| # TODO: decouple reliance on this regex for matching attention | ||
| # TODO: also apply is_attention_module() fallback in initialize_quantization | ||
| targets.add("re:.*self_attn$") | ||
|
|
||
| return targets | ||
|
|
@@ -226,6 +226,12 @@ def initialize_quantization(self, model: torch.nn.Module): | |
| # disable quantization until calibration | ||
| model.apply(disable_quantization) | ||
|
|
||
| def _start_calibrating_module(self, module: torch.nn.Module): | ||
| """Initialize observers, register calibration hooks, and set status.""" | ||
| self._initialize_observers(module) | ||
| self._calibration_hooks |= self._initialize_hooks(module) | ||
| apply_calibration_status(module) | ||
|
|
||
| def start_calibration(self, model: torch.nn.Module): | ||
| """ | ||
| Attach observers, register activation calibration hooks (including | ||
|
|
@@ -238,9 +244,18 @@ def start_calibration(self, model: torch.nn.Module): | |
| untie_word_embeddings(model) | ||
|
|
||
| for _, module in match_named_modules(model, self.resolved_targets, self.ignore): | ||
| self._initialize_observers(module) | ||
| self._calibration_hooks |= self._initialize_hooks(module) | ||
| apply_calibration_status(module) | ||
| self._start_calibrating_module(module) | ||
|
|
||
| # Fallback: catch attention modules missed by the "re:.*self_attn$" regex. | ||
| if self.resolved_config.kv_cache_scheme is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with the update to targets, ideally we can remove this fallback |
||
| for _, module in model.named_modules(): | ||
| if ( | ||
| is_attention_module(module) | ||
| and hasattr(module, "quantization_scheme") | ||
| and getattr(module, "quantization_status", None) | ||
| != QuantizationStatus.CALIBRATION | ||
| ): | ||
| self._start_calibrating_module(module) | ||
|
|
||
| model.apply(enable_quantization) # quantize at the same time as calibrate | ||
|
|
||
|
|
@@ -255,6 +270,16 @@ def end_calibration(self, model: torch.nn.Module): | |
| for _, module in match_named_modules(model, self.resolved_targets, self.ignore): | ||
| freeze_module_quantization(module) # remove observers | ||
|
|
||
| # Also freeze attention modules missed by the regex fallback in start_calibration. | ||
| if self.resolved_config.kv_cache_scheme is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with update to targets, ideally we can remove this fallback |
||
| for _, module in model.named_modules(): | ||
| if ( | ||
| is_attention_module(module) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note if we need to update |
||
| and getattr(module, "quantization_status", None) | ||
| == QuantizationStatus.CALIBRATION | ||
| ): | ||
| freeze_module_quantization(module) | ||
|
|
||
| model.apply(enable_quantization) # keep quantization enabled | ||
|
|
||
| def has_config(self) -> bool: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| """ | ||
| Test that start_calibration initializes KV cache observers on attention modules | ||
| regardless of their name in the module tree. | ||
|
|
||
| compressed_tensors' _apply_kv_cache_scheme discovers attention modules via | ||
| is_attention_module() (name-agnostic), but QuantizationMixin.start_calibration | ||
| previously relied on resolved_targets which includes "re:.*self_attn$". This | ||
| regex fails for models that name their attention differently (e.g. "attention", | ||
| "self_attention"). The fix adds an is_attention_module() fallback pass. | ||
| """ | ||
|
|
||
| import torch.nn as nn | ||
| from compressed_tensors.quantization import ( | ||
| QuantizationArgs, | ||
| QuantizationStatus, | ||
| apply_quantization_config, | ||
| is_attention_module, | ||
| ) | ||
| from transformers import PretrainedConfig | ||
|
|
||
| from llmcompressor.modifiers.quantization.quantization import QuantizationModifier | ||
|
|
||
|
|
||
| class _StubAttention(nn.Module): | ||
| """Minimal attention module recognized by is_attention_module(). | ||
|
|
||
| is_attention_module checks: "attention" in class name (lowercase) and | ||
| hasattr(k_proj) or hasattr(v_proj). | ||
| """ | ||
|
|
||
| def __init__(self, dim: int = 16): | ||
| super().__init__() | ||
| self.q_proj = nn.Linear(dim, dim, bias=False) | ||
| self.k_proj = nn.Linear(dim, dim, bias=False) | ||
| self.v_proj = nn.Linear(dim, dim, bias=False) | ||
| self.o_proj = nn.Linear(dim, dim, bias=False) | ||
|
|
||
| def forward(self, x): | ||
| return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x)) | ||
|
|
||
|
|
||
| class _StubBlock(nn.Module): | ||
| def __init__(self, dim: int = 16): | ||
| super().__init__() | ||
| # Named "attention" instead of "self_attn" — does NOT match | ||
| # the "re:.*self_attn$" regex in resolved_targets. | ||
| self.attention = _StubAttention(dim) | ||
| self.mlp = nn.Linear(dim, dim) | ||
|
|
||
| def forward(self, x): | ||
| return self.mlp(self.attention(x)) | ||
|
|
||
|
|
||
| class _StubModel(nn.Module): | ||
| def __init__(self, dim: int = 16, num_layers: int = 2): | ||
| super().__init__() | ||
| self.config = PretrainedConfig( | ||
| num_attention_heads=1, | ||
| num_key_value_heads=1, | ||
| hidden_size=dim, | ||
| ) | ||
| self.layers = nn.ModuleList([_StubBlock(dim) for _ in range(num_layers)]) | ||
|
|
||
| def forward(self, x): | ||
| for layer in self.layers: | ||
| x = layer(x) | ||
| return x | ||
|
|
||
|
|
||
| def test_attention_module_not_named_self_attn_gets_calibrated(): | ||
| """Attention modules named 'attention' (not 'self_attn') must still | ||
| get observers and hooks initialized when kv_cache_scheme is set.""" | ||
| model = _StubModel(dim=16) | ||
| modifier = QuantizationModifier( | ||
| targets=["Linear"], | ||
| kv_cache_scheme=QuantizationArgs( | ||
| num_bits=8, | ||
| type="float", | ||
| strategy="tensor", | ||
| dynamic=False, | ||
| symmetric=True, | ||
| ), | ||
| ) | ||
|
|
||
| # Verify our stub is recognized as attention | ||
| attn_modules = [ | ||
| (name, m) for name, m in model.named_modules() if is_attention_module(m) | ||
| ] | ||
| assert len(attn_modules) == 2, "Expected 2 attention modules in _StubModel" | ||
|
|
||
| # Verify none of them match the self_attn regex | ||
| for name, _ in attn_modules: | ||
| assert "self_attn" not in name, f"{name} unexpectedly matches self_attn" | ||
|
|
||
| # Apply quantization config (this uses is_attention_module and WILL set scheme) | ||
| apply_quantization_config(model, modifier.resolved_config) | ||
|
|
||
| # Verify schemes were applied to attention modules by _apply_kv_cache_scheme | ||
| for _, m in attn_modules: | ||
| assert hasattr(m, "quantization_scheme"), ( | ||
| "apply_quantization_config should set " | ||
| "quantization_scheme on attention modules" | ||
| ) | ||
|
|
||
| # Now run start_calibration — this is what we're testing | ||
| modifier.start_calibration(model) | ||
|
|
||
| # Verify attention modules got calibration status | ||
| for name, m in attn_modules: | ||
| assert m.quantization_status == QuantizationStatus.CALIBRATION, ( | ||
| f"Attention module '{name}' was not calibrated — " | ||
| f"start_calibration missed it (status={m.quantization_status})" | ||
| ) | ||
|
|
||
| # Verify hooks were registered for KV cache calibration | ||
| assert ( | ||
| len(modifier._calibration_hooks) > 0 | ||
| ), "Expected calibration hooks to be registered" | ||
|
|
||
| # Clean up | ||
| modifier.end_calibration(model) | ||
|
|
||
| for name, m in attn_modules: | ||
| assert ( | ||
| m.quantization_status == QuantizationStatus.FROZEN | ||
| ), f"Attention module '{name}' was not frozen after end_calibration" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.