
Fix KV cache calibration for attention modules not named self_attn #2477

Open

changjonathanc wants to merge 3 commits into vllm-project:main from changjonathanc:fix/kv-cache-calibration-attention-modules

Conversation

@changjonathanc

SUMMARY:
_apply_kv_cache_scheme (in compressed-tensors) discovers attention modules via is_attention_module(), which is name-agnostic. However, start_calibration only iterates modules matching resolved_targets, which includes "re:.*self_attn$" for KV cache. This regex misses attention modules with different names (e.g. "attention", "self_attention"), leaving their observers uninitialized and KV cache scales as garbage values.

Add a fallback pass in start_calibration and end_calibration that uses is_attention_module() to catch any attention modules missed by the regex. Gated by kv_cache_scheme is not None so there is zero cost when KV cache quantization is not used. This addresses the existing TODO: "decouple reliance on this regex for matching attention".

TEST PLAN:

  • Added unit test with a stub model whose attention modules are named attention (not self_attn). Verifies observers are initialized, hooks are registered, and modules are frozen correctly.
  • All 51 existing quantization modifier tests pass.
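The shape of such a stub can be sketched as follows (a hedged illustration, not the actual test file; the class names match those described in the changelog below, and the hidden size of 8 is arbitrary):

```python
import re
import torch.nn as nn

class _StubAttention(nn.Module):
    # Class name contains "attention" and exposes k_proj/v_proj,
    # so a class-name-based check like is_attention_module() recognizes it.
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.k_proj = nn.Linear(hidden, hidden)
        self.v_proj = nn.Linear(hidden, hidden)

class _StubBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # Stored under the attribute "attention", NOT "self_attn",
        # so the path regex "re:.*self_attn$" never matches it.
        self.attention = _StubAttention()

class _StubModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([_StubBlock() for _ in range(2)])

model = _StubModel()
paths = [name for name, _ in model.named_modules()]
# No module path ends in "self_attn", so a regex-only pass finds nothing.
missed = [p for p in paths if re.match(r".*self_attn$", p)]
```

Any calibration pass that only walks paths matching `re:.*self_attn$` will skip every attention module in this model.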

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue in KV cache calibration where attention modules not explicitly named self_attn were being overlooked, leading to incorrect quantization. By introducing a name-agnostic fallback mechanism, the system now reliably identifies and calibrates all attention modules, ensuring consistent and accurate KV cache quantization across diverse model architectures. This enhancement improves the robustness of the quantization process and is backed by new unit tests.

Highlights

  • Improved KV Cache Calibration: Implemented a fallback mechanism in start_calibration to ensure all attention modules, regardless of their naming convention (e.g., "attention" vs. "self_attn"), are correctly identified and initialized for KV cache quantization.
  • Consistent Module Freezing: Extended end_calibration with a corresponding fallback to properly freeze quantization for attention modules that were initialized via the new mechanism.
  • Performance Optimization: Ensured the new fallback logic is only active when KV cache quantization is enabled, preventing any overhead when it's not in use.
  • Enhanced Test Coverage: Added a dedicated unit test to validate the fix, using a stub model with non-standard attention module names to confirm correct calibration behavior.


Changelog
  • src/llmcompressor/modifiers/quantization/quantization/mixin.py
    • Added a fallback loop in start_calibration to initialize observers and hooks for attention modules identified by is_attention_module() but missed by resolved_targets regex, specifically for KV cache.
    • Introduced a corresponding fallback loop in end_calibration to freeze quantization for these attention modules.
  • tests/llmcompressor/modifiers/quantization/test_kv_cache_calibration.py
    • Created a new test file to verify the KV cache calibration fix.
    • Implemented _StubAttention, _StubBlock, and _StubModel classes to simulate a model with attention modules named attention (not self_attn).
    • Added test_attention_module_not_named_self_attn_gets_calibrated to assert that these modules are correctly calibrated and frozen.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request addresses an issue where KV cache calibration was not correctly applied to attention modules unless they were named self_attn. The proposed solution adds a fallback pass in both start_calibration and end_calibration to identify any missed attention modules using the name-agnostic is_attention_module function. This ensures all relevant modules are properly calibrated and frozen. The changes are well-reasoned and are accompanied by a targeted unit test that validates the fix. My review includes one suggestion to refactor a small amount of duplicated code for better maintainability.

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: this is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.

compressed_tensors' _apply_kv_cache_scheme uses is_attention_module()
to set quantization_scheme on attention modules (name-agnostic), but
start_calibration only iterates modules matching resolved_targets which
includes "re:.*self_attn$". This regex misses attention modules with
different names (e.g. "attention", "self_attention"), leaving their KV
cache observers uninitialized and scales as garbage values.

Add a fallback pass in start_calibration and end_calibration that uses
is_attention_module() to catch any attention modules missed by the
regex-based target matching. This closes the gap noted in the existing
TODO: "decouple reliance on this regex for matching attention".

Signed-off-by: Jonathan Chang <changjonathanc@users.noreply.github.com>
@changjonathanc force-pushed the fix/kv-cache-calibration-attention-modules branch from d27cc26 to e534e49 on March 17, 2026 at 11:37
@brian-dellabetta
Collaborator

However, start_calibration only iterates modules matching resolved_targets, which includes "re:.*self_attn$" for KV cache

Hi @changjonathanc , can you point me to where you're seeing this? These lines in compressed-tensors? Is it just a matter of updating the targets list to include "re:.*(self_attn|attention)$"?

@changjonathanc
Copy link
Author

  if self.resolved_config.kv_cache_scheme is not None:
      # TODO: decouple reliance on this regex for matching attention
      targets.add("re:.*self_attn$")

@brian-dellabetta
Collaborator

  if self.resolved_config.kv_cache_scheme is not None:
      # TODO: decouple reliance on this regex for matching attention
      targets.add("re:.*self_attn$")

Hi @changjonathanc , can't we just update the regex to "re:.*(self_attn|attention)$"?

@changjonathanc
Author

Hi @brian-dellabetta ,

Honestly, I'm not 100% sure, so here is what Claude told me, and I think it makes sense:


When you do KV cache quantization, the library needs to:

  1. Find all attention modules in the model
  2. Attach a quantization scheme to them
  3. Initialize calibration observers/hooks on them

The bug is that steps 1-2 and step 3 use different methods to find the attention modules, so they don't always agree on which modules to process.


Step 1 & 2: How modules get their scheme attached

This happens in compressed-tensors, in _apply_kv_cache_scheme:

  for submodule in model.modules():         # iterate ALL modules
      if is_attention_module(submodule):    # check: is this an attention module?
          submodule.quantization_scheme = scheme

is_attention_module checks the class name of the module:

  "attention" in module.__class__.__name__.lower()
  # AND has k_proj, v_proj, qkv_proj, or kv_b_proj

So a module whose Python class is named LlamaAttention passes this check — regardless of what it's called inside the model tree (self_attn, attention, mha, whatever).

After this step, any attention module has quantization_scheme set on it.


Step 3: How calibration observers get initialized

This happens in llm-compressor, in start_calibration:

  for _, module in match_named_modules(model, self.resolved_targets, ...):
      self._initialize_observers(module)
      ...

match_named_modules uses self.resolved_targets, which contains:
"re:.*self_attn$" # a regex matching the module's PATH in the tree

This regex matches module paths like model.layers.0.self_attn — it checks the name you'd use to look up the module, not the class name.
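Stripping the re: prefix (which marks the target as a regex), the path-matching gap is easy to demonstrate; the module paths here are illustrative:

```python
import re

# The pattern behind the "re:.*self_attn$" target
pattern = re.compile(r".*self_attn$")

assert pattern.match("model.layers.0.self_attn") is not None  # matched
assert pattern.match("layers.0.attention") is None            # missed
assert pattern.match("layers.0.self_attention") is None       # also missed
```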


The Bug

Imagine a model where the attention module is stored at path layers.0.attention (not self_attn), but the class is still LlamaAttention. Steps 1-2 attach a quantization_scheme to it (the class-name check passes), but step 3 never initializes its observers (the path regex misses it), so its KV cache scales end up as garbage values.


What the PR does

It adds a fallback in start_calibration: after the regex loop, do a second pass using is_attention_module() — the exact same check that _apply_kv_cache_scheme uses — to catch anything the regex missed:

  # First pass: regex-based (existing)
  for _, module in match_named_modules(model, self.resolved_targets, ...):
      self._start_calibrating_module(module)
      initialized_ids.add(id(module))

  # Second pass: class-name-based fallback (new)
  if self.resolved_config.kv_cache_scheme is not None:
      for _, module in model.named_modules():
          if id(module) not in initialized_ids and is_attention_module(module) ...:
              self._start_calibrating_module(module)

Now both sides use is_attention_module() as the source of truth, so they always agree.
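The two-pass idea can be reduced to a dependency-free toy (the module layout and class names here are hypothetical stand-ins, not the real llm-compressor code):

```python
import re

# Toy stand-ins for (module path, module class name) pairs
modules = [
    ("model.layers.0.self_attn", "LlamaAttention"),
    ("model.layers.1.attention", "LlamaAttention"),
    ("model.layers.0.mlp", "LlamaMLP"),
]

def is_attention_like(class_name: str) -> bool:
    # Simplified stand-in for is_attention_module()
    return "attention" in class_name.lower()

path_regex = re.compile(r".*self_attn$")
calibrated = set()

# First pass: regex on module paths (existing behavior)
for path, cls in modules:
    if path_regex.match(path):
        calibrated.add(path)

# Second pass: class-name fallback (the PR's addition), skipping
# modules the first pass already handled
for path, cls in modules:
    if path not in calibrated and is_attention_like(cls):
        calibrated.add(path)

# Both attention modules end up calibrated, regardless of path name
```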

@brian-dellabetta
Collaborator

It adds a fallback in start_calibration: after the regex loop, do a second pass using is_attention_module() — the exact same check that _apply_kv_cache_scheme uses — to catch anything the regex missed

@changjonathanc rather than re-applying after the fact, we just need to update the regex to match on both self_attn and attention. The one I proposed above should be able to handle that without needing custom logic after the fact.

Unfortunately there's no guaranteed way to determine if a module is an attention module and we have to rely on these fuzzy matches, so we'll have to update when necessary as naming conventions drift. Can you try with that one-line change to the regex, rather than the changes you have in src in this PR, and see if your tests still pass?

@changjonathanc
Author

Hi @brian-dellabetta, I think your suggestion passes the test, but the confusion still exists: one part of the code matches by the class name of a module, while the other matches by the path. wdyt?

@brian-dellabetta
Collaborator

brian-dellabetta commented Mar 20, 2026

Hi @brian-dellabetta, I think your suggestion passes the test, but the confusion still exists: one part of the code matches by the class name of a module, while the other matches by the path. wdyt?

Hi @changjonathanc , these are two separate checks:

  1. given a pointer to a torch.nn.Module, how can I know if it is an attention module? We have some fuzzy matching logic to check the class name of the module and if it has attributes like k_proj or qkv_proj. If we see new class names, like I know deepseek v3.2 uses MLA as the class name for multi-latent attention module, we will need to update this check.
  2. given that I want to apply kv_cache quantization to a model, what should the regex be to catch all possible naming conventions in use? It's model.layers.0.self_attn in most models, model.layers.0.attn in deepseek v3.2, and you hit a case where it's just *.attention.

In case (1) we will have to update the conditional logic to capture any drift in module class naming conventions; in case (2) we will have to update the regex to capture any drift in model submodule naming conventions.
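The two checks, and the drift each one is exposed to, can be seen side by side (a sketch; the LlamaAttention/MLA stand-ins and the paths are illustrative):

```python
import re

# Check (2): the widened path regex proposed above
path_re = re.compile(r".*(self_attn|attention)$")
assert path_re.match("model.layers.0.self_attn") is not None
assert path_re.match("layers.0.attention") is not None
assert path_re.match("model.layers.0.attn") is None  # deepseek-style path still missed

# Check (1): class-name fuzzy matching, simplified from is_attention_module()
class LlamaAttention:
    k_proj = object()  # attribute presence is part of the real check

class MLA:  # deepseek-style multi-latent attention: class name drifts too
    k_proj = object()

def looks_like_attention(mod) -> bool:
    return "attention" in type(mod).__name__.lower() and hasattr(mod, "k_proj")

assert looks_like_attention(LlamaAttention())
assert not looks_like_attention(MLA())
```

Each check has its own blind spots, which is why the two must be updated independently as naming conventions drift.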

@changjonathanc
Author

Hi @brian-dellabetta, that makes sense. But since compressed-tensors already decided which modules get quantization_scheme via is_attention_module(), llm-compressor is re-deriving that same decision with a regex that can drift independently. We actually hit this exact bug in our codebase and had to monkey-patch it for the same reason. Happy to go with the simpler regex fix if you prefer consistency with the rest of the codebase though.

@brian-dellabetta
Collaborator

Hi @brian-dellabetta, that makes sense. But since compressed-tensors already decided which modules get quantization_scheme via is_attention_module(), llm-compressor is re-deriving that same decision with a regex that can drift independently. We actually hit this exact bug in our codebase and had to monkey-patch it for the same reason. Happy to go with the simpler regex fix if you prefer consistency with the rest of the codebase though.

Thanks @changjonathanc for the detail, good to know. I almost replied that a better solution would be to try to consolidate the logic, but that might be better as an RFC. I will raise this issue with the team next week and discuss internally

@brian-dellabetta brian-dellabetta self-assigned this Mar 23, 2026
Collaborator

@brian-dellabetta left a comment


Hi @changjonathanc , I was able to discuss this today internally. The is_attention_module check is needed because compressed-tensors does not have access to targets when applying the kv cache scheme (see here). The kv cache scheme must be applied model-wide based on our data model; QuantizationScheme has no notion of targets. To avoid any breaking changes to that API, I suggest we

  1. add a patch to your use case by updating targets, and
  2. update is_attention_module in compressed-tensors, if necessary, if it's not catching your use case.

Will ping over vLLM Slack in case it's easier to discuss there.

  if self.resolved_config.kv_cache_scheme is not None:
      # TODO: decouple reliance on this regex for matching attention
      # TODO: also apply is_attention_module() fallback in initialize_quantization
      targets.add("re:.*self_attn$")

Suggested change:

  - targets.add("re:.*self_attn$")
  + targets.add("re:.*(self_attn|attention)$")

  self._start_calibrating_module(module)

  # Fallback: catch attention modules missed by the "re:.*self_attn$" regex.
  if self.resolved_config.kv_cache_scheme is not None:

With the update to targets, ideally we can remove this fallback.

  freeze_module_quantization(module)  # remove observers

  # Also freeze attention modules missed by the regex fallback in start_calibration.
  if self.resolved_config.kv_cache_scheme is not None:

With the update to targets, ideally we can remove this fallback.

  if self.resolved_config.kv_cache_scheme is not None:
      for _, module in model.named_modules():
          if (
              is_attention_module(module)

Note: if we need to update is_attention_module, we can do that in a compressed-tensors PR.
