Create `is_prefill` bool mask from `cu_seqlens_q` to identify per-batch type (prefill vs decode). Both chunk_prefill and paged_decode kernels receive all q/k/v/o data but skip non-matching batches using the mask:

- chunk_prefill skips decode batches (`is_prefill[idx_b] == false`)
- paged_decode skips prefill batches (`is_prefill[idx_b] == true`)

This handles mixed batches where most sequences are decoding with a few prefills, without needing to split/reorder tensors.

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Pull request overview
Adds a per-batch is_prefill mask to the XE2 XPU attention path to better handle mixed batches (some sequences in prefill, some in decode) by letting each kernel skip batches that don’t apply.
Changes:
- Thread a new optional `is_prefill` mask through XPU attention interfaces into XE2 cutlass kernel params.
- Update chunk-prefill and paged-decode kernels to conditionally skip batches based on `is_prefill[idx_b]`.
- Update XPU flash-attn varlen forward dispatch to build the mask for paged mode and invoke both kernels.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| csrc/xpu/attn/xe_2/paged_decode_xe2.h | Add is_prefill parameter to XE2 paged-decode entrypoint. |
| csrc/xpu/attn/xe_2/paged_decode_xe2.cpp | Forward is_prefill to impl and into kernel args. |
| csrc/xpu/attn/xe_2/paged_decode_utils.hpp | Update impl declaration to accept is_prefill. |
| csrc/xpu/attn/xe_2/paged_decode.hpp | Add is_prefill pointer into paged_decode_args_t and launcher argument packing. |
| csrc/xpu/attn/xe_2/kernel/paged_decode_kernel.hpp | Skip prefill batches inside paged-decode and reduce kernels using the mask. |
| csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp | Skip decode batches inside chunk-prefill kernel using the mask. |
| csrc/xpu/attn/xe_2/fmha_xe2.h | Add is_prefill parameter to XE2 chunk-prefill entrypoint. |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Forward is_prefill to impl and into kernel args. |
| csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp | Update impl declaration to accept is_prefill. |
| csrc/xpu/attn/xe_2/chunk_prefill.hpp | Add is_prefill pointer into chunk_prefill_args_t and launcher argument packing. |
| csrc/xpu/attn/attn_interface.h | Extend XPU cutlass attention interfaces with is_prefill. |
| csrc/xpu/attn/attn_interface.cpp | Forward is_prefill through the interface dispatch to XE2 implementations. |
| csrc/flash_attn/flash_api.cpp | Build is_prefill mask for paged mode and dispatch both kernels. |
Comments suppressed due to low confidence (1)
csrc/xpu/attn/xe_2/paged_decode_xe2.cpp:180
`is_prefill` is passed to the kernel as a raw pointer (`is_prefill.value().data_ptr()`) without runtime validation. Because the device kernels do `p.is_prefill[idx_b]`, an incorrect mask tensor (wrong dtype, wrong device, non-contiguous, or `numel() != batch_size`) can cause invalid device memory reads. Please add `TORCH_CHECK` validation for dtype `kBool`, contiguity, `numel() == batch_size`, and device match before setting `args.is_prefill`.
```cpp
      static_cast<float>(sm_scale),
      is_sink ? sm_sink_.value().data_ptr() : nullptr,
      batch_size,
      num_heads_q,
      num_heads_kv,
```
```diff
       is_paged, // paged
       is_causal,
       is_local,
-      is_sink};
+      is_sink,
+      is_prefill.has_value() ? is_prefill.value().data_ptr() : nullptr};
```
When is_prefill is provided, the code forwards is_prefill.value().data_ptr() straight into chunk_prefill_args_t without validating the tensor. Since the kernels index this pointer by idx_b, a wrong dtype/device/shape (e.g., not kBool, not on the same XPU device, or numel() != batch_size) can lead to out-of-bounds reads or device faults. Please add TORCH_CHECKs to enforce is_prefill->scalar_type() == at::kBool, is_prefill->is_contiguous(), is_prefill->numel() == batch_size, and that it’s on the same device as query (and optionally stride(0)==1).
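The checks requested above could be sketched as follows. This is a standalone stand-in, not the repo's code: `MaskInfo` is a hypothetical struct holding the tensor properties in question, and each `throw` stands where a `TORCH_CHECK(...)` would go in the real torch C++ code:

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the at::Tensor properties the review asks to
// validate before forwarding the raw pointer into chunk_prefill_args_t.
struct MaskInfo {
  bool is_bool_dtype;   // scalar_type() == at::kBool
  bool is_contiguous;   // is_contiguous()
  int64_t numel;        // numel()
  int device_index;     // device().index()
};

// Sketch of the validation; in the real code each branch would be a
// TORCH_CHECK with a descriptive message instead of a throw.
void validate_is_prefill(const MaskInfo& m, int64_t batch_size, int q_device) {
  if (!m.is_bool_dtype)
    throw std::invalid_argument("is_prefill must have dtype kBool");
  if (!m.is_contiguous)
    throw std::invalid_argument("is_prefill must be contiguous");
  if (m.numel != batch_size)
    throw std::invalid_argument("is_prefill numel() must equal batch_size");
  if (m.device_index != q_device)
    throw std::invalid_argument("is_prefill must be on the same device as query");
}
```

Running these checks once on the host is cheap relative to a kernel launch and turns a silent out-of-bounds device read into an immediate, actionable error.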
```diff
@@ -187,31 +222,28 @@ std::vector<at::Tensor> mha_varlen_fwd(
       is_local ? std::min(max_seqlen_k, eff_window_left + 1) : max_seqlen_k;

   int num_tokens = q.size(0);
   int batch_size = static_cast<int>(cu_seqlens_q.size(0)) - 1;
   int num_heads_q = q.size(1);
   int head_dim = q.size(2);
   int num_heads_kv = k.size(2);
-  int block_size = k.size(1);
+  int kv_block_size = k.size(1);

   int num_kv_splits = num_splits.value_or(get_num_splits(
-      queue, batch_size, num_heads_kv, effective_seqlen_k, block_size));
+      queue, batch_size, num_heads_kv, effective_seqlen_k, kv_block_size));

   at::Tensor tmp_out =
       num_kv_splits == 1
           ? out
           : at::empty(
                 {num_tokens, num_heads_q * num_kv_splits, head_dim},
                 q.options().device(q.device()));
-  at::Tensor max_logits = at::full(
+  at::Tensor decode_max_logits = at::full(
       {num_tokens, num_heads_q, num_kv_splits},
       -std::numeric_limits<float>::infinity(),
       q.options().dtype(at::kFloat).device(q.device()));
-  at::Tensor exp_sums = at::zeros(
+  at::Tensor decode_exp_sums = at::zeros(
       {num_tokens, num_heads_q, num_kv_splits},
       q.options().dtype(at::kFloat).device(q.device()));
```
In the paged path, the code now always builds an is_prefill_mask and dispatches both chunk-prefill and paged-decode. This means decode-side temporary buffers (tmp_out, decode_max_logits, decode_exp_sums) get allocated and the decode kernel gets launched even when there are no decode batches (e.g., all seq_len_q > 1), which can add significant overhead and memory pressure. Consider adding a fast-path to skip the unused kernel + allocations (e.g., skip chunk-prefill when q.size(0) == batch_size / all decode; and skip paged-decode when the mask indicates all prefill, using a cheap predicate or a guarded reduction).
```diff
       tile_scheduler.get_block_coord(); // (Q,V,h,b,id_split)

+  // Skip prefill batches when is_prefill mask is provided
+  if (p.is_prefill != nullptr && p.is_prefill[idx_b]) continue;
```
The per-tile continue skip (if (p.is_prefill != nullptr && p.is_prefill[idx_b]) continue;) prevents compute, but the kernel still iterates over all tiles for prefill batches via tile_scheduler. For long prefill sequences this can be a large fixed overhead even though work is skipped. If this is performance-critical, consider incorporating the mask into the scheduler/grid construction (e.g., build a list/range of decode batches and only schedule those) so skipped batches don’t enumerate tiles at all.
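The scheduling-stage filter suggested above amounts to compacting the batch dimension before building the grid: instead of enumerating tiles for every batch and `continue`-ing on the mask, launch only over the batches the kernel owns and map the scheduler's batch coordinate through an index list. A host-side sketch of that compaction (the function name is illustrative; wiring it into `tile_scheduler` is not shown):

```cpp
#include <cstdint>
#include <vector>

// Sketch: build a compacted list of decode batch indices so the
// paged-decode grid iterates only over decode batches. Inside the kernel,
// the scheduler's batch coordinate idx_b would be remapped via
// idx[idx_b] instead of being tested against the mask per tile.
// (The mirror list, with the condition inverted, would drive chunk-prefill.)
std::vector<int32_t> decode_batch_indices(const std::vector<bool>& is_prefill) {
  std::vector<int32_t> idx;
  for (int32_t b = 0; b < static_cast<int32_t>(is_prefill.size()); ++b) {
    if (!is_prefill[b]) idx.push_back(b);  // keep only decode batches
  }
  return idx;
}
```

The trade-off is an extra host pass and one more kernel argument versus eliminating the per-tile enumeration overhead for skipped batches, which matters most when a few long prefills sit next to many decodes.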
```diff
       tile_scheduler.get_block_coord(); // (Q,V,h,b)

+  // Skip decode batches when is_prefill mask is provided
+  if (p.is_prefill != nullptr && !p.is_prefill[idx_b]) continue;
```
The kernel skips non-prefill batches with if (p.is_prefill != nullptr && !p.is_prefill[idx_b]) continue;, but it still enumerates tiles for those batches through tile_scheduler. For decode-heavy steps this adds overhead even though most work is skipped. If mixed batches are common, consider moving the batch-type filtering into the scheduling stage (so the grid only includes the relevant batches) to avoid iterating over tiles that will always continue.