[Feature] Add per-request attention capture to the OpenAI-compatible API #35014

Open
Parkprogrammer wants to merge 3 commits into vllm-project:main from Parkprogrammer:feat/attention-instrumentation-pr

Conversation

@Parkprogrammer commented Feb 21, 2026

Closes #11365

Motivation

Interpretability researchers and multimodal debugging workflows need access to raw attention scores at inference time without patching model code or writing custom inference loops.

This PR exposes per-request attention instrumentation through the existing OpenAI-compatible API. When not requested, there is zero additional compute cost — no forward-pass modification, no graph break, no torch.compile impact.


What This PR Does

Adds an opt-in, per-request attention capture mechanism to vLLM's OpenAI API server. When the server is started with instrumentation enabled and a client explicitly requests capture, Q×Kᵀ softmax attention scores are computed
post-forward, serialised via shared memory, and attached to the chat completion response.


New Server Flags

--enable-attention-instrumentation        # Enable feature (disabled by default)
--attention-instrumentation-layers all   # Default capture layer set (overridable per-request)

New Request Fields (Chat Completions)

Field                 Type         Description
attn_capture          int (0/1)    Enable capture for this request
attn_capture_layers   str          Comma-separated layer indices or "all"

New Response Field

attn_capture_data — list of per-layer objects:

[
  {
    "layer_idx": 8,
    "data": "<base64-gzipped tensor>",
    "shape": [T, H, T],
    "dtype": "float16",
    "token_meta": {
      "token_idx_basis": 0,
      "prompt_len": 42,
      "total_len": 50,
      "vision_ranges": [[0, 35]]
    }
  }
]

Shape is [T, H, T] where T = prompt_len + generated_len - 1 (all tokens visible to the attention kernel — the final sampled token has no subsequent Q step). H = query heads.

token_meta.vision_ranges maps image token spans for cross-modal analysis.
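
As an illustration of how these fields fit together, the sketch below slices a decoded [T, H, T] score matrix (decoding is shown under Example Usage further down) to measure how much attention mass generated tokens place on image tokens. This helper is not part of the PR; its name, the treatment of query rows at positions >= prompt_len as generated-token queries, and the assumption that vision_ranges are half-open [start, end) spans are all illustrative.

import numpy as np

def vision_attention_fraction(scores: np.ndarray, token_meta: dict) -> float:
    """Mean softmax mass that generated-token queries place on vision-token keys.

    scores: one decoded [T, H, T] array from attn_capture_data.
    token_meta: the token_meta dict from the same entry.
    """
    prompt_len = token_meta["prompt_len"]
    gen_queries = scores[prompt_len:]  # query rows for already-generated tokens
    if gen_queries.size == 0:
        return 0.0
    # Boolean mask over key positions covered by vision_ranges
    # (assumed half-open [start, end) spans).
    vision_keys = np.zeros(scores.shape[-1], dtype=bool)
    for start, end in token_meta["vision_ranges"]:
        vision_keys[start:end] = True
    # Each row sums to ~1 over the visible (causal) keys, so the masked sum is
    # the fraction of attention directed at image tokens.
    return float(gen_queries[..., vision_keys].sum(axis=-1).mean())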


Design Notes

Why post-forward computation?

  • Scores are computed after the forward pass completes, leaving model kernels and the attention backend entirely unmodified. No graph break or torch.compile side effect is introduced.

Why Q buffering?

  • Q tensors are captured per-slot during the forward pass and stored in a per-worker buffer. After the request finishes,
    K is read from the KV cache and Q×Kᵀ is computed once. This avoids any impact on the hot path.
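
For reference, a minimal sketch of what this post-forward computation looks like for a single layer (illustrative only; the PR's actual implementation in attn_capture.py additionally handles slot mapping, KV-cache layout, and compression):

import torch

def qk_softmax_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Recompute softmax attention scores for one layer after the forward pass.

    q: [T, H_q, D] buffered query vectors for one request.
    k: [T, H_kv, D] keys gathered from the KV cache for the same request.
    Returns scores of shape [T, H_q, T].
    """
    seq_len, num_q_heads, head_dim = q.shape
    num_kv_heads = k.shape[1]
    if num_q_heads != num_kv_heads:
        # GQA/MQA: repeat each KV head across its group of query heads
        # (assumes the usual grouping of consecutive query heads per KV head).
        k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
    # Upcast for a numerically stable softmax, scale by 1/sqrt(D).
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) / head_dim**0.5
    # Causal mask: query i may only attend to keys j <= i.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    logits = logits.masked_fill(~causal, float("-inf"))
    scores = torch.softmax(logits, dim=-1).to(q.dtype)
    return scores.permute(1, 0, 2)  # [H, T, T] -> [T, H, T]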

Prefix caching interaction

  • When prefix caching is active, prompt tokens are cache-hit and their Q vectors are not recomputed. A persistent q_cache (FIFO-capped at 32 768 slots, ~128 MB at float16/128d/16h) stores Q vectors across requests so prefix-cached tokens are still included in the captured score matrix.
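
A minimal sketch of such a FIFO-capped cache (illustrative; the class and method names are not taken from the PR):

from collections import OrderedDict
from typing import Optional

import torch

class QCache:
    """FIFO-capped store of per-token query vectors, keyed by KV-cache slot."""

    def __init__(self, max_slots: int = 32_768):
        self.max_slots = max_slots
        self._store: "OrderedDict[int, torch.Tensor]" = OrderedDict()

    def put(self, slot: int, q: torch.Tensor) -> None:
        if slot in self._store:
            self._store.move_to_end(slot)  # refresh position on overwrite
        self._store[slot] = q
        while len(self._store) > self.max_slots:
            self._store.popitem(last=False)  # evict the oldest slot (FIFO)

    def get(self, slot: int) -> Optional[torch.Tensor]:
        return self._store.get(slot)

With 128-dim heads, 16 query heads, and float16 entries, 32 768 slots works out to 32 768 × 16 × 128 × 2 bytes = 128 MiB, matching the worst-case figure above.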

Why shared memory?

  • Attention tensors can be large. Shared memory avoids copying them through the engine's RPC transport layer.
    Each snapshot is keyed by request_id and cleaned up immediately after the output processor reads it.

Concurrency and isolation

  • Each request uses its own shared-memory segment named by request_id, so concurrent requests are fully isolated. Segments are unlinked after reading; a 30-second read timeout is used to handle slow capture paths gracefully.
    If the engine crashes before writing, the segment simply times out and attn_capture_data is absent from the response (no server crash).
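
A minimal sketch of this per-request shared-memory handshake (illustrative; the segment naming and length-prefix layout are assumptions, not the PR's exact wire format):

import time
from multiprocessing import shared_memory
from typing import Optional

def write_snapshot(request_id: str, payload: bytes) -> None:
    """Engine side: publish one length-prefixed payload per request."""
    shm = shared_memory.SharedMemory(
        name=f"attn_{request_id}", create=True, size=8 + len(payload)
    )
    shm.buf[8:8 + len(payload)] = payload
    shm.buf[:8] = len(payload).to_bytes(8, "little")  # size written last acts as a "ready" flag
    shm.close()

def read_snapshot(request_id: str, timeout_s: float = 30.0) -> Optional[bytes]:
    """Reader side: poll for the segment, copy it out, then unlink it."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            shm = shared_memory.SharedMemory(name=f"attn_{request_id}")
        except FileNotFoundError:
            time.sleep(0.01)
            continue
        size = int.from_bytes(bytes(shm.buf[:8]), "little")
        if size == 0:  # segment exists but the writer has not finished yet
            shm.close()
            time.sleep(0.01)
            continue
        data = bytes(shm.buf[8:8 + size])
        shm.close()
        shm.unlink()  # per-request segment is removed after a single read
        return data
    return None  # timed out: attn_capture_data is simply absent from the response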

Streaming compatibility

  • Attention scores are only available in non-streaming responses (stream=False). Streaming responses do not include attn_capture_data because the snapshot is only available after the full sequence is generated.

Memory Impact

Per captured request (non-streaming only):

T (seq len)   H (query heads)   Size (float16)
512           16                ~8 MB
2 048         16                ~128 MB
4 096         32                ~1 GB

Tensors are gzip-compressed before serialisation, typically achieving 3–5×
reduction. The persistent Q cache is FIFO-capped at 32 768 slots (~128 MB
worst-case) to bound memory growth across requests.

Capture is disabled by default. No automatic hard cap on response size is
enforced — for very long sequences, callers are expected to limit scope via
attn_capture_layers.
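
For a quick sanity check of the table above, the uncompressed payload is T × H × T × 2 bytes at float16:

def capture_size_bytes(seq_len: int, num_heads: int, dtype_bytes: int = 2) -> int:
    """Uncompressed size of one [T, H, T] score tensor."""
    return seq_len * num_heads * seq_len * dtype_bytes

for t, h in [(512, 16), (2048, 16), (4096, 32)]:
    print(f"T={t:5d}, H={h:2d}: {capture_size_bytes(t, h) / 2**20:7.0f} MiB")
# -> 8 MiB, 128 MiB, 1024 MiB, matching the rows in the table above.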


Implementation

Component                          Change
attn_capture.py (new, ~600 LOC)    Q buffering, Q×Kᵀ computation, GQA/MQA support, shared-memory IPC
gpu_model_runner.py                Hook into forward pass; call capture() post-request
attention.py                       Buffer the Q tensor to the capture hook
output_processor.py                Load snapshot from shared memory and attach to RequestOutput
protocol.py                        Add new request/response fields
arg_utils.py, cache.py             Add CLI flags and configuration

The core logic of attn_capture.py is ~350 lines; the remainder is IPC utilities
and validation helpers. It can be split into separate modules if preferred.


Example Usage

from openai import OpenAI
import json, base64, gzip, numpy as np

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

raw = client.chat.completions.with_raw_response.create(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"attn_capture": 1, "attn_capture_layers": "8,16"},
)

resp = json.loads(raw.content)

for item in resp["attn_capture_data"]:
    blob = gzip.decompress(base64.b64decode(item["data"]))
    scores = np.frombuffer(blob, dtype=item["dtype"]).reshape(item["shape"])
    print(f"Layer {item['layer_idx']}: shape={scores.shape}")
    # scores: [T, H, T] — rows=query positions, cols=key positions

See examples/offline_inference/attention_instrumentation/ for full examples
including multimodal token classification and cross-modal attention analysis.


Test Results

Tested on Qwen2.5-VL-3B-Instruct and Gemma-3-4B-IT across 9 input
groups × 6 capture configurations = 54 cases per model, 108 total.
All 108/108 cases passed.

Test script: Parkprogrammer/vllm_capture

Summary:

  • No regression when attn_capture=0 (instrumentation-enabled server, capture disabled per-request)
  • Overhead is a fixed per-request cost; it does not scale with sequence length
  • Medium/long prompt groups show near-parity with baseline throughput
Full per-group breakdown

Input groups: text·{short,medium,long}, text+image·{short,medium,long}, image only·{short,medium,long}

Capture configs: off (attn_capture=0), early (layer 2), mid (layer 8), late (layer 15), multi (2,8,15), all

Qwen2.5-VL-3B-Instruct

Group Pass AvgLat AvgTok/s Off Tok/s BaseTok/s AvgPTok
text · short 6/6 0.69s 34.9 35.9 48.0 28
text · medium 6/6 1.16s 41.4 51.2 50.8 76
text · long 6/6 1.19s 40.4 50.3 50.0 432
text+image · short 6/6 0.94s 25.6 10.8 47.1 494
text+image · medium 6/6 1.21s 39.7 48.6 49.6 542
text+image · long 6/6 1.18s 40.6 50.3 49.6 898
image only · short 6/6 0.81s 39.7 50.1 45.4 485
image only · medium 6/6 1.18s 40.8 50.1 45.7 485
image only · long 6/6 1.19s 40.3 48.5 47.3 485

Gemma-3-4B-IT

Group Pass AvgLat AvgTok/s Off Tok/s BaseTok/s AvgPTok
text · short 6/6 0.93s 23.7 27.1 26.8 18
text · medium 6/6 1.89s 25.4 28.6 28.7 64
text · long 6/6 1.88s 25.5 28.7 28.7 423
text+image · short 6/6 1.35s 16.4 17.9 17.6 278
text+image · medium 6/6 2.73s 17.6 18.8 19.0 324
text+image · long 6/6 2.77s 17.4 18.7 18.8 683
image only · short 6/6 1.86s 17.2 18.9 18.7 268
image only · medium 6/6 2.77s 17.3 18.1 19.0 268
image only · long 6/6 2.78s 17.3 18.7 18.7 268

Column legend:

  • AvgTok/s: mean tok/s across all 6 capture configs, including the "all" config
  • Off Tok/s: attn_capture=0 on instrumentation-enabled server
  • BaseTok/s: plain vLLM server with no --enable-attention-instrumentation flag (2 warm-up requests; indicative, single GPU)

dosubot bot commented Feb 21, 2026

Related Documentation

Checked 0 published document(s) in 1 knowledge base(s). No updates required.


mergify bot commented Feb 21, 2026

Documentation preview: https://vllm--35014.org.readthedocs.build/en/35014/

mergify bot added the documentation and frontend labels on Feb 21, 2026
mergify bot commented Feb 21, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Parkprogrammer.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

…roject#11365)

- Query buffering at request processing time
- Post-request Q×K^T score computation with causal mask
- Per-layer and layer range selection via CLI flags
- Multimodal token range support (vision/language)
- Shared memory IPC for efficient cross-process transfer
- Integration with GPU model runner and attention layers

Signed-off-by: Jehyun Park <jaheon555@g.skku.edu>
…ect#11365)

- Add attn_capture and attn_capture_layers parameters to chat completions API
- Return attention scores in response via attn_capture_data field
- Support per-request capture control and layer selection
- Add CLI flags: --enable-attention-instrumentation and --attention-instrumentation-layers
- Load attention snapshots from shared memory in output processor
- Include attention data in RequestOutput for client delivery

Signed-off-by: Jehyun Park <jaheon555@g.skku.edu>
…ct#11365)

- Python utilities for extracting and analyzing attention scores
- Multimodal token classification (vision/language/generated)
- Cross-modal attention measurement
- OpenAI SDK and cURL usage examples
- Comprehensive guide with quick start and API reference

Signed-off-by: Jehyun Park <jaheon555@g.skku.edu>
Parkprogrammer force-pushed the feat/attention-instrumentation-pr branch from d9032b6 to 76eaeea on February 21, 2026 09:01
mergify bot removed the needs-rebase label on Feb 21, 2026
mergify bot commented Feb 21, 2026

Hi @Parkprogrammer, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@DarkLight1337 (Member) left a comment

We have previously rejected returning hidden states directly in the request output because it adds too much IPC overhead (see #15434 (comment)); the same could be said for this PR. I suggest following a similar approach to #33118

@gemini-code-assist bot left a comment

Code Review

The pull request introduces a per-request attention capture mechanism for the OpenAI-compatible API. While this is a valuable feature for interpretability and debugging, the current implementation has several critical performance and security issues. The most significant concerns are the introduction of GPU-CPU synchronization points during the forward pass, blocking I/O operations in the main engine loop that can hang the server for up to 30 seconds, and the use of pickle for inter-process communication, which presents a security risk. Additionally, the full attention matrix computation lacks memory safeguards, potentially leading to OOM crashes on long sequences.

Comment on lines +367 to +372
prompt_token_ids = [0] * len(self.prompt_embeds)

# Load attention capture data if capture was requested
attn_capture_data = None
if finished and getattr(self, "attn_capture_enabled", False):
from vllm.model_executor.layers.attention.attn_capture import (

critical

The call to load_attn_snapshot is a blocking operation that polls shared memory for up to 30 seconds with time.sleep. Since this is executed within process_outputs, which runs in the main engine loop (or the output_handler task in AsyncLLM), it will block the entire engine from processing any other requests. This is a critical issue that can cause severe latency spikes or hang the server. The data should be loaded in a non-blocking manner or offloaded to a background thread.

@Parkprogrammer (Author) commented Feb 21, 2026

The blocking load_attn_snapshot() call will be removed from output_processor. In the rework, the engine returns only an out-of-band handle (via kv_transfer_params), and no polling happens on the engine loop.

# Detaching the query tensor for buffering. Severing CUDA computation trace
# Query tensor stays in the dict for all requests
try:
query_cpu = query.detach().cpu().clone()

high

query.detach().cpu().clone() is a blocking operation that forces a GPU-CPU synchronization during the model's forward pass. This is called for every attention layer being captured, which will significantly degrade the throughput of the entire batch whenever a capture request is present. Consider keeping the tensors on the GPU and performing the attention computation there, or using asynchronous memory copies to avoid blocking the forward pass.

@Parkprogrammer (Author) commented Feb 21, 2026

The query.detach().cpu().clone() path introduces an implicit GPU↔CPU sync. I will replace it with an async D2H copy into pinned memory on a dedicated stream to avoid blocking the forward hot path.
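
A minimal sketch of the non-blocking copy pattern described here (illustrative; the actual rework may structure this differently):

import torch

_copy_stream = torch.cuda.Stream()  # dedicated side stream for capture copies

def buffer_query_async(query: torch.Tensor):
    """Start a device-to-host copy without stalling the forward stream."""
    pinned = torch.empty_like(query, device="cpu", pin_memory=True)
    done = torch.cuda.Event()
    # Order the copy after whatever produced `query` on the current stream.
    _copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_copy_stream):
        pinned.copy_(query.detach(), non_blocking=True)
        done.record()
    # NOTE: `query` must be kept alive until `done` has completed.
    return pinned, done

# Later, off the hot path (e.g. in the capture task):
#   done.synchronize()   # or done.query() to poll
#   q_cpu = pinned       # now safe to read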

time.sleep(0.01)
continue
try:
data = pickle.loads(bytes(mem.buf[8 : 8 + size]))

security-high high

Using pickle.loads on data read from shared memory is a security risk, as it can be exploited to execute arbitrary code. Since this is an internal IPC mechanism, it is better to use a safer and more efficient serialization format like msgspec or json.

@Parkprogrammer (Author) commented Feb 21, 2026

I will remove pickle from the IPC path and replace it with a safe binary format (fixed struct header + raw tensor bytes). No object deserialization.
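
A minimal sketch of such a pickle-free layout (illustrative; the header fields and dtype codes here are assumptions, not the PR's final format):

import struct

import numpy as np

_HEADER = struct.Struct("<IIIH2x")          # layer_idx, T, H, dtype code, padding
_DTYPES = {0: np.float16, 1: np.float32}

def pack_scores(layer_idx: int, scores: np.ndarray) -> bytes:
    seq_len, num_heads, _ = scores.shape
    code = 0 if scores.dtype == np.float16 else 1
    return _HEADER.pack(layer_idx, seq_len, num_heads, code) + scores.tobytes()

def unpack_scores(blob: bytes):
    layer_idx, seq_len, num_heads, code = _HEADER.unpack_from(blob)
    flat = np.frombuffer(blob, dtype=_DTYPES[code], offset=_HEADER.size)
    return layer_idx, flat.reshape(seq_len, num_heads, seq_len)

# Parsing is pure struct/NumPy: no object deserialization, so a corrupted or
# malicious buffer can at worst raise an error, never execute code.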

"""Capture attention and clean up buffers for a finished request."""
req_state = self.requests.get(req_id)
if not req_state:
return

high

self.attn_capture.capture(...) is called synchronously within the worker's execution loop for finished requests. This function performs heavy GPU computations (full NxN attention) and CPU-intensive tasks (gzip compression). For long sequences, this will introduce a significant delay before the next batch can start executing, impacting overall engine throughput. This processing should be offloaded to a background thread.

@Parkprogrammer (Author) commented Feb 21, 2026

capture() will be offloaded so the worker execution loop is non-blocking. The worker will enqueue a capture job and proceed; the capture task performs post-forward compute + compression asynchronously and writes the result out-of-band.

return k.to(dtype) # NOTE(jehyun): Uniform dtype for downstream float16 computation


def compute_qk_attention(

high

compute_qk_attention computes a full [T, H, T] attention matrix. For long sequences, this can consume a massive amount of GPU memory (e.g., ~8GB for T=8192, H=32 in float32). The current implementation lacks safeguards or limits on the sequence length or the number of layers, making the server vulnerable to OOM crashes via the public API. A hard limit on the total number of elements captured per request should be enforced.

@Parkprogrammer (Author) commented Feb 21, 2026

I will add hard per-request resource caps (e.g., --max-attn-capture-tokens, --max-attn-capture-bytes, and/or a maximum layer count). Requests exceeding the limits will be rejected with an explicit status instead of attempting the allocation.
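
A minimal sketch of the kind of pre-admission check this implies (illustrative; the flag names above are proposals and the defaults below are placeholders):

def validate_capture_request(
    prompt_len: int,
    max_new_tokens: int,
    num_layers: int,
    num_heads: int,
    max_capture_tokens: int = 4096,        # would come from --max-attn-capture-tokens
    max_capture_bytes: int = 512 * 2**20,  # would come from --max-attn-capture-bytes
) -> None:
    """Reject a capture request up front instead of attempting the allocation."""
    seq_len = prompt_len + max_new_tokens
    if seq_len > max_capture_tokens:
        raise ValueError(
            f"attn_capture: sequence length {seq_len} exceeds cap {max_capture_tokens}"
        )
    estimated = num_layers * seq_len * num_heads * seq_len * 2  # float16 bytes
    if estimated > max_capture_bytes:
        raise ValueError(
            f"attn_capture: estimated {estimated} bytes exceeds cap {max_capture_bytes}"
        )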

@Parkprogrammer (Author) commented

We have previously rejected returning hidden states directly in the request output because it adds too much IPC overhead (see #15434 (comment)); the same could be said for this PR. I suggest following a similar approach to #33118

Thanks for the pointer to #33118 and the earlier discussion on tensor outputs. I’m going to rework this PR to avoid returning large tensors through the request output / engine IPC:

  • Remove attn_capture_data from RequestOutput; return only an out-of-band handle via kv_transfer_params (KVConnectorBase_V1).
  • Remove blocking/polling from the engine loop (output_processor no longer reads snapshots).
  • Replace pickle with a safe binary IPC format.
  • Add hard per-request caps (tokens/bytes/layers) to prevent OOM/DoS.
  • Replace the forward-pass .cpu() path with async pinned D2H, and make capture() non-blocking w.r.t the worker loop.

Does this direction align with what you had in mind with #33118? If there are additional constraints you want me to follow, I’m happy to incorporate them before pushing the rework.

@DarkLight1337 (Member) commented

I think this is a relatively niche feature, so I'd prefer to minimize intrusion into our existing code. I am not sure whether we really want to return the tensors via the API either; perhaps it's better for the client to connect to the KV cache directly using the returned handle?

cc @WoosukKwon @robertgshaw2-redhat @njhill


Labels

documentation, frontend, v1

Development

Successfully merging this pull request may close these issues.

[Feature]: Add support for attention score output

2 participants