Conversation
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Pull request overview
This PR focuses on performance tuning for the Xe2 chunk-prefill kernel targeting head_dim=128, alongside adjustments for PagedKV (BNHS) handling and build/tooling updates.
Changes:
- Retunes head_dim=128 tile shapes and rewrites the chunk-prefill mainloop (including a new softmax rescale path).
- Refactors PagedKV K/V tensor shaping/striding and mainloop call signatures to operate on block-structured K/V.
- Updates SYCL build flags, CUTLASS revision, and adds CMake support for excluding kernel sources.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp | Adjusts PagedKV tensor shapes/layout selection and mainloop invocation for paged vs non-paged paths. |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Updates PagedKV dimension interpretation (BNHS) and head/block size extraction. |
| csrc/xpu/attn/xe_2/fmha_utils.hpp | Retunes head_dim=128 chunk policy tile shapes for performance. |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp | Large mainloop rewrite: PagedKV block addressing, prefetch strategy, and softmax rescale refactor. |
| csrc/xpu/attn/xe_2/chunk_prefill.hpp | Updates packed strides and passes additional PagedKV metadata (num_heads_kv/total blocks). |
| csrc/flash_attn/flash_api.cpp | Comments out several decode variables (currently breaking compilation due to remaining uses). |
| cmake/utils.cmake | Adds EXCLUDE_SOURCES argument to kernel library helper. |
| CMakeLists.txt | Updates SYCL flags, SYCL link targets, and CUTLASS revision pin. |
csrc/flash_attn/flash_api.cpp (outdated)
```cpp
// int num_tokens = q.size(0);
// int batch_size = static_cast<int>(cu_seqlens_q.size(0)) - 1;
// int num_heads_q = q.size(1);
// int head_dim = q.size(2);
// int num_heads_kv = k.size(2);
// int block_size = k.size(1);

int num_kv_splits = num_splits.value_or(get_num_splits(
    queue, batch_size, num_heads_kv, effective_seqlen_k, block_size));
// int num_kv_splits = num_splits.value_or(get_num_splits(
//     queue, batch_size, num_heads_kv, effective_seqlen_k, block_size));

at::Tensor tmp_out =
    num_kv_splits == 1
```
The local variable num_kv_splits is commented out but is still used to size/branch tmp_out (and likely later tensors). This will not compile; either restore num_kv_splits (and any required inputs like effective_seqlen_k) or refactor the downstream logic to not depend on it.
```cpp
static constexpr bool Fp8KV =
    is_any_of_v<ElementK, float_e5m2_t, float_e4m3_t>;
static constexpr bool Fp8KV = false;
```
Hard-coding Fp8KV = false removes the FP8 KV handling that previously depended on the KV element type, which is a functional regression if FP8 KV is supported by this path. Restore the original type-based detection (and keep the scaling logic guarded by if constexpr (Fp8KV)), or explicitly gate this optimization to only the configurations that never use FP8.
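A minimal, self-contained sketch of the type-based detection being asked for. The FP8 element types and `is_any_of_v` trait here are stand-ins for the CUTLASS types used in the diff (`cutlass::float_e5m2_t` etc.); only the detection pattern itself is the point:

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for the CUTLASS FP8 element types referenced in
// the diff; the real header would use cutlass::float_e5m2_t / float_e4m3_t.
struct float_e5m2_t {};
struct float_e4m3_t {};
struct half_t {};

// Variadic "is any of" trait, matching the is_any_of_v used in the diff.
template <class T, class... Ts>
inline constexpr bool is_any_of_v = (std::is_same_v<T, Ts> || ...);

template <class ElementK>
struct KvConfig {
  // Type-based detection instead of a hard-coded `false`, so FP8 KV caches
  // keep their scaling path; downstream code stays guarded by
  // `if constexpr (Fp8KV)` and compiles away for non-FP8 configurations.
  static constexpr bool Fp8KV =
      is_any_of_v<ElementK, float_e5m2_t, float_e4m3_t>;
};

static_assert(KvConfig<float_e4m3_t>::Fp8KV);
static_assert(!KvConfig<half_t>::Fp8KV);
```

Because the flag is `constexpr`, gating the scaling logic with `if constexpr (Fp8KV)` imposes no cost on the non-FP8 instantiations, which removes the motivation for hard-coding `false`.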
```cpp
// Perf. knobs
#define STAGES 2
#define SYNC 1
#define POSTPROCESSING 1
#define PREFETCH_V_EARLIER 1
```
Defining generic macros (STAGES, SYNC, etc.) in a public header is risky (name collisions across translation units, hard-to-debug build differences). Prefer static constexpr/inline constexpr constants in the relevant struct/namespace (or template parameters), or at least prefix macros (e.g., VLLM_XE2_CHUNK_PREFILL_STAGES) and #undef them at the end of the header.
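A sketch of the suggested alternative: scoped `constexpr` constants instead of bare macros. The namespace and constant names here are illustrative, not the kernel's actual identifiers:

```cpp
#include <cassert>

// Hypothetical namespace; the real header would use whatever namespace the
// chunk-prefill kernel already lives in.
namespace vllm_xe2_chunk_prefill {

// Perf. knobs as typed, scoped constants: no collision risk across
// translation units, visible in debuggers, and usable as template arguments.
struct PerfKnobs {
  static constexpr int kStages = 2;
  static constexpr bool kSync = true;
  static constexpr bool kPostprocessing = true;
  static constexpr bool kPrefetchVEarlier = true;
};

}  // namespace vllm_xe2_chunk_prefill
```

Unlike `#define`, these constants obey scoping rules and carry a type, so a `STAGES` defined elsewhere in another header cannot silently change this kernel's pipeline depth.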
```cpp
        ? make_ordered_layout(shape_Q, Step<_2, _0, _1, _3>{})
        : make_layout(shape_Q, p.dQ);
auto layout_k = (PagedKV || is_var_len)
        ? make_ordered_layout(shape_K, Step<_2, _0, _1, _3>{})
        ? make_layout(shape_K, p.dK)
        : make_layout(shape_K, p.dK);
auto layout_v = (PagedKV || is_var_len)
        ? make_ordered_layout(shape_V, Step<_0, _2, _1, _3>{})
        ? make_layout(shape_V, p.dV)
        : make_layout(shape_V, p.dV);
```
The ternaries for layout_k and layout_v are now redundant (both branches are identical), which reduces readability and can hide intended layout differences for is_var_len. Either simplify to a single make_layout(...) or reintroduce the distinct var-len/paged behavior if it was required for correctness/perf.
Previously K/V used make_ordered_layout(...) when is_var_len (and also for paged), but now K/V always use make_layout(...) while Q/O still use ordered layouts for var-len. This asymmetry can change the logical indexing for var-len runs and may lead to incorrect addressing if var-len expects the ordered layout steps. If var-len is supported here, please restore the ordered layout path for K/V (or document why it is no longer needed and prove equivalence).
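To make the asymmetry concrete, here is a minimal self-contained sketch of what an ordered layout does, without using CUTLASS. It is an assumption-labeled model of `cute::make_ordered_layout`, not its implementation: given a shape and an order permutation (the `Step<...>` values), strides are assigned so the dimension with order 0 varies fastest:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Sketch (not CUTLASS code): assign strides from an order permutation.
// The dimension whose order value is 0 gets stride 1; each next rank gets
// stride = previous stride * previous extent, mirroring the intent of
// cute::make_ordered_layout(shape, Step<...>{}).
template <std::size_t N>
std::array<long, N> ordered_strides(const std::array<long, N>& shape,
                                    const std::array<int, N>& order) {
  std::array<long, N> strides{};
  long running = 1;
  for (int rank = 0; rank < static_cast<int>(N); ++rank) {
    for (std::size_t d = 0; d < N; ++d) {
      if (order[d] == rank) {      // this dimension is the rank-th fastest
        strides[d] = running;
        running *= shape[d];
      }
    }
  }
  return strides;
}
```

With `Step<_2, _0, _1, _3>` the second dimension becomes the fastest-varying one, which generally differs from whatever packed strides `p.dK` describes; that is why silently swapping `make_ordered_layout` for `make_layout(shape_K, p.dK)` can change var-len addressing rather than being a pure cleanup.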
csrc/xpu/attn/xe_2/chunk_prefill.hpp (outdated)
```cpp
args.block_size,
args.max_blocks_per_seq,
args.total_seqlen_k,
args.num_heads_k,
args.is_paged ? args.total_seqlen_k / args.block_size : 0,
```
FMHAFwdMainloop::Arguments now expects num_heads_kv, but the launcher passes args.num_heads_k. If args.num_heads_k is not exactly the KV head count (or is named inconsistently), this is error-prone. Prefer passing the already-computed num_heads_kv (or rename the field to num_heads_kv) to make the contract unambiguous.
csrc/xpu/attn/xe_2/fmha_xe2.cpp (outdated)
```cpp
batch_size = query.size(0);
num_heads_q = query.size(1);
num_heads_kv = is_paged ? key_cache.size(2) : key_cache.size(1);
num_heads_kv = is_paged ? key_cache.size(1) : key_cache.size(1);
```
The conditional is now redundant and reads like a copy/paste error. Replace with a single assignment to avoid confusion (and consider adding a brief comment noting why the head dimension is size(1) in both layouts after the BNHS change).
Suggested change:

```cpp
// After BNHS layout change, heads are at dimension 1 for both paged and non-paged key_cache.
num_heads_kv = key_cache.size(1);
```
```cpp
// V head tile offset for output head dimension
int vv_base = get<1>(blk_qv) * VTiles;

const int barrier_scope = 2;
```
barrier_scope = 2 is a magic number and makes it hard to reason about correctness if barrier semantics change. Use the existing named constant/enum that was previously used (e.g., ScopeSubgroup or the appropriate CUTLASS/sycl scope constant used by barrier_arrive/barrier_wait) to avoid fragile coupling.
Suggested change:

```cpp
enum : int {
  kBarrierScopeSubgroup = 2
};
const int barrier_scope = kBarrierScopeSubgroup;
```
Can you make HND support a separate PR, since it belongs to new feature support?
Revert BNHS [num_blocks, num_heads, block_size, head_size] layout changes while keeping all performance tuning optimizations:

- Tile shape tuning (head128: 128->256)
- Barrier arrive/wait synchronization
- Lane-based causal + k-remainder masking
- Softmax returns rescale, fused with PV gemm
- K prefetch with Stages pipeline
- V prefetch reordering
- Build flags and CUTLASS revision update
- Scale broadcast fix for paged decode

HND layout changes moved to yizhou/HND_format branch.

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Force-pushed 477e277 to 2d0acf0.
No description provided.