[ATTN] mix batch perf tuning #218

Open
YizhouZ wants to merge 3 commits into vllm-project:main from YizhouZ:yizhou/mix_prefill_decode

Conversation

@YizhouZ (Collaborator) commented Mar 24, 2026

No description provided.

YizhouZ added 2 commits March 24, 2026 01:14
Create is_prefill bool mask from cu_seqlens_q to identify per-batch
type (prefill vs decode). Both chunk_prefill and paged_decode kernels
receive all q/k/v/o data but skip non-matching batches using the mask:
- chunk_prefill skips decode batches (is_prefill[idx_b] == false)
- paged_decode skips prefill batches (is_prefill[idx_b] == true)

This handles mixed batches where most sequences are decoding with a
few prefills, without needing to split/reorder tensors.

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
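The mask construction described in the commit message can be sketched as follows (a minimal host-side illustration in plain C++; `build_is_prefill` is a hypothetical name, and the actual PR passes the mask to the kernels as a device pointer):

```cpp
#include <vector>

// Hypothetical sketch: derive a per-batch is_prefill mask from cu_seqlens_q.
// A batch whose query length exceeds 1 is treated as prefill; a single-token
// query is a decode step.
std::vector<bool> build_is_prefill(const std::vector<int>& cu_seqlens_q) {
  const int batch_size = static_cast<int>(cu_seqlens_q.size()) - 1;
  std::vector<bool> is_prefill(batch_size);
  for (int b = 0; b < batch_size; ++b) {
    const int q_len = cu_seqlens_q[b + 1] - cu_seqlens_q[b];
    is_prefill[b] = (q_len > 1);
  }
  return is_prefill;
}
```

For example, cu_seqlens_q = {0, 5, 6, 7} yields {true, false, false}: one 5-token prefill followed by two single-token decodes.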
Copilot AI review requested due to automatic review settings March 24, 2026 08:28
@YizhouZ YizhouZ force-pushed the yizhou/mix_prefill_decode branch from 5e07239 to 5f6a64a, March 24, 2026 08:31
Copilot AI (Contributor) left a comment


Pull request overview

Adds a per-batch is_prefill mask to the XE2 XPU attention path to better handle mixed batches (some sequences in prefill, some in decode) by letting each kernel skip batches that don’t apply.

Changes:

  • Thread a new optional is_prefill mask through XPU attention interfaces into XE2 cutlass kernel params.
  • Update chunk-prefill and paged-decode kernels to conditionally skip batches based on is_prefill[idx_b].
  • Update XPU flash-attn varlen forward dispatch to build the mask for paged mode and invoke both kernels.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 4 comments.

Summary per file:

  • csrc/xpu/attn/xe_2/paged_decode_xe2.h: Add is_prefill parameter to XE2 paged-decode entrypoint.
  • csrc/xpu/attn/xe_2/paged_decode_xe2.cpp: Forward is_prefill to impl and into kernel args.
  • csrc/xpu/attn/xe_2/paged_decode_utils.hpp: Update impl declaration to accept is_prefill.
  • csrc/xpu/attn/xe_2/paged_decode.hpp: Add is_prefill pointer into paged_decode_args_t and launcher argument packing.
  • csrc/xpu/attn/xe_2/kernel/paged_decode_kernel.hpp: Skip prefill batches inside paged-decode and reduce kernels using the mask.
  • csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp: Skip decode batches inside chunk-prefill kernel using the mask.
  • csrc/xpu/attn/xe_2/fmha_xe2.h: Add is_prefill parameter to XE2 chunk-prefill entrypoint.
  • csrc/xpu/attn/xe_2/fmha_xe2.cpp: Forward is_prefill to impl and into kernel args.
  • csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp: Update impl declaration to accept is_prefill.
  • csrc/xpu/attn/xe_2/chunk_prefill.hpp: Add is_prefill pointer into chunk_prefill_args_t and launcher argument packing.
  • csrc/xpu/attn/attn_interface.h: Extend XPU cutlass attention interfaces with is_prefill.
  • csrc/xpu/attn/attn_interface.cpp: Forward is_prefill through the interface dispatch to XE2 implementations.
  • csrc/flash_attn/flash_api.cpp: Build is_prefill mask for paged mode and dispatch both kernels.
Comments suppressed due to low confidence (1)

csrc/xpu/attn/xe_2/paged_decode_xe2.cpp:180

  • is_prefill is passed to the kernel as a raw pointer (is_prefill.value().data_ptr()) without runtime validation. Because the device kernels do p.is_prefill[idx_b], an incorrect mask tensor (wrong dtype, wrong device, non-contiguous, or numel() != batch_size) can cause invalid device memory reads. Please add TORCH_CHECK validation for dtype kBool, contiguity, numel()==batch_size, and device match before setting args.is_prefill.
      static_cast<float>(sm_scale),
      is_sink ? sm_sink_.value().data_ptr() : nullptr,
      batch_size,
      num_heads_q,
      num_heads_kv,


Comment on lines 146 to +150
      is_paged,  // paged
      is_causal,
      is_local,
-     is_sink};
+     is_sink,
+     is_prefill.has_value() ? is_prefill.value().data_ptr() : nullptr};
Copilot AI Mar 24, 2026


When is_prefill is provided, the code forwards is_prefill.value().data_ptr() straight into chunk_prefill_args_t without validating the tensor. Since the kernels index this pointer by idx_b, a wrong dtype/device/shape (e.g., not kBool, not on the same XPU device, or numel() != batch_size) can lead to out-of-bounds reads or device faults. Please add TORCH_CHECKs to enforce is_prefill->scalar_type() == at::kBool, is_prefill->is_contiguous(), is_prefill->numel() == batch_size, and that it’s on the same device as query (and optionally stride(0)==1).
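The validation the reviewer asks for can be sketched with a stand-in struct (MockMask and validate_is_prefill are hypothetical names; in the real code these would be TORCH_CHECKs against the optional at::Tensor's scalar_type(), is_contiguous(), numel(), and device()):

```cpp
#include <stdexcept>

// Hypothetical stand-in for the suggested mask validation. MockMask mimics
// the handful of tensor properties the checks need.
struct MockMask {
  bool is_bool_dtype;   // stands in for scalar_type() == at::kBool
  bool contiguous;      // stands in for is_contiguous()
  long numel;           // stands in for numel()
  int device_index;     // stands in for device index
};

// Throws if the mask would be unsafe to index as a raw bool* of batch_size
// elements on the query's device.
void validate_is_prefill(const MockMask& m, long batch_size, int query_device) {
  if (!m.is_bool_dtype)
    throw std::invalid_argument("is_prefill must be a kBool tensor");
  if (!m.contiguous)
    throw std::invalid_argument("is_prefill must be contiguous");
  if (m.numel != batch_size)
    throw std::invalid_argument("is_prefill numel must equal batch_size");
  if (m.device_index != query_device)
    throw std::invalid_argument("is_prefill must be on the query's device");
}
```

Running these checks once on the host is cheap relative to a kernel launch and turns a silent out-of-bounds device read into a clear error.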

Comment on lines 183 to 245
@@ -187,31 +222,28 @@ std::vector<at::Tensor> mha_varlen_fwd(
       is_local ? std::min(max_seqlen_k, eff_window_left + 1) : max_seqlen_k;

   int num_tokens = q.size(0);
   int batch_size = static_cast<int>(cu_seqlens_q.size(0)) - 1;
   int num_heads_q = q.size(1);
   int head_dim = q.size(2);
   int num_heads_kv = k.size(2);
-  int block_size = k.size(1);
+  int kv_block_size = k.size(1);

   int num_kv_splits = num_splits.value_or(get_num_splits(
-      queue, batch_size, num_heads_kv, effective_seqlen_k, block_size));
+      queue, batch_size, num_heads_kv, effective_seqlen_k, kv_block_size));

   at::Tensor tmp_out =
       num_kv_splits == 1
           ? out
           : at::empty(
                 {num_tokens, num_heads_q * num_kv_splits, head_dim},
                 q.options().device(q.device()));
-  at::Tensor max_logits = at::full(
+  at::Tensor decode_max_logits = at::full(
       {num_tokens, num_heads_q, num_kv_splits},
       -std::numeric_limits<float>::infinity(),
       q.options().dtype(at::kFloat).device(q.device()));
-  at::Tensor exp_sums = at::zeros(
+  at::Tensor decode_exp_sums = at::zeros(
       {num_tokens, num_heads_q, num_kv_splits},
       q.options().dtype(at::kFloat).device(q.device()));
Copilot AI Mar 24, 2026


In the paged path, the code now always builds an is_prefill_mask and dispatches both chunk-prefill and paged-decode. This means decode-side temporary buffers (tmp_out, decode_max_logits, decode_exp_sums) get allocated and the decode kernel gets launched even when there are no decode batches (e.g., all seq_len_q > 1), which can add significant overhead and memory pressure. Consider adding a fast-path to skip the unused kernel + allocations (e.g., skip chunk-prefill when q.size(0) == batch_size / all decode; and skip paged-decode when the mask indicates all prefill, using a cheap predicate or a guarded reduction).
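The suggested fast-path amounts to a single host-side pass over the mask before dispatch (a sketch with hypothetical names; the real decision would also gate the tmp_out/decode_max_logits/decode_exp_sums allocations):

```cpp
#include <vector>

// Classify a mixed batch so that kernels (and their temporary buffers) with
// no work are skipped entirely.
enum class Dispatch { DecodeOnly, PrefillOnly, Mixed };

Dispatch classify_batch(const std::vector<bool>& is_prefill) {
  int n_prefill = 0;
  for (bool p : is_prefill) n_prefill += p ? 1 : 0;
  if (n_prefill == 0)
    return Dispatch::DecodeOnly;   // skip chunk-prefill launch
  if (n_prefill == static_cast<int>(is_prefill.size()))
    return Dispatch::PrefillOnly;  // skip paged-decode launch and its buffers
  return Dispatch::Mixed;          // dispatch both kernels with the mask
}
```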

Comment on lines 255 to +259
      tile_scheduler.get_block_coord(); // (Q,V,h,b,id_split)

+     // Skip prefill batches when is_prefill mask is provided
+     if (p.is_prefill != nullptr && p.is_prefill[idx_b]) continue;

Copilot AI Mar 24, 2026


The per-tile continue skip (if (p.is_prefill != nullptr && p.is_prefill[idx_b]) continue;) prevents compute, but the kernel still iterates over all tiles for prefill batches via tile_scheduler. For long prefill sequences this can be a large fixed overhead even though work is skipped. If this is performance-critical, consider incorporating the mask into the scheduler/grid construction (e.g., build a list/range of decode batches and only schedule those) so skipped batches don’t enumerate tiles at all.

Comment on lines 252 to +256
      tile_scheduler.get_block_coord(); // (Q,V,h,b)

+     // Skip decode batches when is_prefill mask is provided
+     if (p.is_prefill != nullptr && !p.is_prefill[idx_b]) continue;

Copilot AI Mar 24, 2026


The kernel skips non-prefill batches with if (p.is_prefill != nullptr && !p.is_prefill[idx_b]) continue;, but it still enumerates tiles for those batches through tile_scheduler. For decode-heavy steps this adds overhead even though most work is skipped. If mixed batches are common, consider moving the batch-type filtering into the scheduling stage (so the grid only includes the relevant batches) to avoid iterating over tiles that will always continue.
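Moving the filter into scheduling, as both comments suggest, can be sketched as compacting the batch indices of the wanted type so the tile scheduler never enumerates tiles for the other type (decode_batches is a hypothetical helper; the real fix would live in the tile scheduler's grid construction):

```cpp
#include <vector>

// Build a compacted list of decode batch indices; scheduling over this list
// instead of [0, batch_size) means prefill batches contribute no tiles at
// all, rather than tiles that immediately `continue`.
std::vector<int> decode_batches(const std::vector<bool>& is_prefill) {
  std::vector<int> out;
  for (int b = 0; b < static_cast<int>(is_prefill.size()); ++b)
    if (!is_prefill[b]) out.push_back(b);  // keep only decode batches
  return out;
}
```

The symmetric helper for chunk-prefill would keep only indices where the mask is true.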

2 participants