
[Cherry-Pick][Feature] support decode attention for mix (#7688) #7729

Open
lizhenyun01 wants to merge 13 commits into PaddlePaddle:release/2.6 from lizhenyun01:dec_attn_2.6

Conversation

@lizhenyun01
Collaborator

Motivation

Adds C16 and static C8 attention support. Usage: with flash_attn enabled, `export USE_DECODE_ATTENTION=1`.

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot Bot commented May 7, 2026

Thanks for your contribution!

@lizhenyun01 lizhenyun01 changed the title [Feature] support decode attention for mix(#7688) [Cherry-Pick][Feature] support decode attention for mix(#7688) May 7, 2026
@PaddlePaddle-bot

PaddlePaddle-bot commented May 7, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-11 12:19:50

The CI report is generated from the latest code (refreshed every 30 minutes):


1 Task Overview

⚠️ 1 required task is currently failing, and 5 more required tasks (3 running, 2 pending) have not finished; merging is not recommended yet.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Pending | Skipped |
|---|---|---|---|---|---|---|
| 36 (0) | 36 | 26 | 4 | 3 | 3 | 0 |

2 Task Status Summary

2.1 Required tasks: 4/10 passed

Required tasks block merging; failures should be fixed first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
|---|---|---|---|---|---|---|
| ❌ Failed | Approval | 8s | PR issue: protected paths modified; 4 kinds of RD approvals missing | Ask the corresponding RDs to complete GitHub Review approvals | Job | - |
| ⏳ Running | Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage | - | - | - | Job | - |
| ⏳ Running | Extracted partial CE model tasks to run in CI. / run_ce_cases | - | - | - | Job | - |
| ⏳ Running | xpu_8cards_case_test / run_xpu_8cards_cases | - | - | - | Job | - |
| ⏸️ Pending | Run Four Cards Tests / run_4_cards_tests | - | - | - | - | - |
| ⏸️ Pending | xpu_4cards_case_test / run_xpu_4cards_cases | - | - | - | - | - |
| ✅ | The remaining 4 required tasks passed | - | - | - | - | - |

2.2 Optional tasks: 22/26 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
|---|---|---|---|---|
| ❌ | Run iluvatar Tests / run_iluvatar_cases | 17m36s | Job | - |
| ❌ | Check PR Template | 17s | Job | - |
| ❌ | Trigger Jenkins for PR | 52s | Job | - |
| ⏸️ | CI_HPU | - | - | - |
| ✅ | The remaining 22 optional tasks passed | - | - | - |

3 Failure Details (required tasks only)

Approval (process approval, confidence: high)

Approval

  • Status: ❌ Failed
  • Error type: process approval
  • Confidence: high
  • Root cause summary: the PR modifies protected paths and lacks 4 kinds of RD approvals
  • Analyzer: generic analysis (fallback)

Root cause details:

The check_approval.sh script detected that this PR modifies protected file paths that require approval from specific RDs; 4 approval requirements are unmet:

  1. Custom-op code was modified: one of the FastDeploy RDs (qingqing01 / Jiang-Jia-Jun / heavengate) must approve
  2. Custom-op code was modified: one of the PaddlePaddle RDs (jeff41404 / yongqiangma) must approve
  3. fastdeploy/spec_decode and custom_ops/gpu_ops/speculate_decoding were modified: one of the FastDeploy RDs (freeliuzc / Deleter-D) must approve
  4. fastdeploy/envs.py was modified: one of the FastDeploy RDs (Jiang-Jia-Jun / yuanlehome / rainyfly / Wanglongzhi2001) must approve

Key log:

==> PR title: [Cherry-Pick][Feature] support decode attention for mix(#7688)
0. You must have one FastDeploy RD (qingqing01, Jiang-Jia-Jun, heavengate) approval for adding custom op.
1. You must have one PaddlePaddle RD (jeff41404, yongqiangma) approval for adding custom op.
2. You must have one FastDeploy RD (freeliuzc, Deleter-D) approval for modifing [fastdeploy/spec_decode,custom_ops/gpu_ops/speculate_decoding].
3. You must have one FastDeploy RD (Jiang-Jia-Jun, yuanlehome, rainyfly, Wanglongzhi2001) approval for modifying [fastdeploy/envs.py].
There are 4 approved errors.
##[error]Process completed with exit code 6.

Suggested fixes:

  1. Have one FastDeploy RD (qingqing01 / Jiang-Jia-Jun / heavengate) approve this PR via GitHub Review
  2. Have one PaddlePaddle RD (jeff41404 / yongqiangma) approve this PR via GitHub Review
  3. Have one FastDeploy RD (freeliuzc / Deleter-D) approve the spec_decode changes
  4. Have one FastDeploy RD (Jiang-Jia-Jun / yuanlehome / rainyfly / Wanglongzhi2001) approve the fastdeploy/envs.py changes

Fix summary: ask the 4 categories of RDs to complete GitHub Review approvals.
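The approval gate described above blocks the merge based on which file paths the PR touches. A minimal sketch of such a path-gated check follows; the file list is stubbed and the variable names are illustrative, not the real check_approval.sh logic (a real check would derive the list from `git diff --name-only`):

```shell
set -u

# Files changed in the PR, one per line (stubbed here for illustration).
CHANGED_FILES="fastdeploy/envs.py
custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh"

errors=0
if echo "$CHANGED_FILES" | grep -q '^fastdeploy/envs\.py$'; then
  echo "You must have one FastDeploy RD approval for modifying [fastdeploy/envs.py]."
  errors=$((errors + 1))
fi
if echo "$CHANGED_FILES" | grep -q '^custom_ops/gpu_ops/'; then
  echo "You must have one FastDeploy RD approval for adding custom op."
  errors=$((errors + 1))
fi

echo "There are $errors approved errors."
# A real gate would then fail the job when approvals are missing, e.g.:
#   [ "$errors" -eq 0 ] || exit 6
# which matches the "Process completed with exit code 6" seen in the log.
```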


PaddlePaddle-bot

This comment was marked as outdated.


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-05-11 12:04:45

📋 Review Summary

PR overview: adds C16 (FP16/BF16 KV cache) and static C8 (INT8 KV cache) decode attention support, enabled via the environment variable USE_DECODE_ATTENTION=1 (requires flash_attn to be enabled).

Scope of changes: `custom_ops/gpu_ops/append_attention/`, `layers/attention/`, `worker/gpu_model_runner.py`, `spec_decode/mtp.py`

Impact tags: [OP] [Feature] [Speculative Decoding]


📝 PR Convention Check

The PR title follows the Cherry-Pick convention ([Cherry-Pick][Feature] ... (#7688)). However, the ## Modifications, ## Usage or Command, and ## Accuracy Tests sections are all empty (placeholder comments only) and no Checklist item is checked, so the description is non-compliant.

Suggested PR description (ready to copy):

## Motivation

Adds C16 (FP16/BF16 KV cache) and static C8 (INT8 KV cache) decode attention support, enabled via `export USE_DECODE_ATTENTION=1` when flash_attn is on, to reduce memory-access pressure and improve decode-phase throughput.

## Modifications

- Add the `custom_ops/gpu_ops/append_attention/` directory with the following new CUDA implementations:
  - `decode_append_attention_c16_impl.cuh`: decode attention kernel for FP16/BF16 KV cache
  - `decode_append_attention_c8_impl.cuh`: decode attention kernel for static INT8 KV cache
  - `attention_func.cuh`, `mma_tensor_op.cuh`, `utils.cuh`, `mem_util.cuh`: shared CUDA utilities
  - `config_for_attention.cu`: new `config_for_attention` op (including `GetMaxLenKernel`)
- Update `custom_ops/setup_ops.py` to add the new CUDA sources to the build list
- Add the `USE_DECODE_ATTENTION` environment variable to `fastdeploy/envs.py`
- Update the Python bindings under `fastdeploy/model_executor/layers/attention/ops/` (`config_for_attention.py`, `decode_append_attention.py`, `decoder_write_cache_with_rope.py`)
- Update `append_attn_backend.py` and `flash_attn_backend.py` to integrate the new decode attention path
- Update `gpu_model_runner.py`, `metax_model_runner.py`, `input_batch.py`, and `spec_decode/mtp.py`
- Add tests: `tests/operators/attention/test_decode_append_attention_c16.py` and `benchmark_decode_attention.py`

## Usage or Command

```bash
# With flash_attn enabled, turn on decode attention via the environment variable
export USE_DECODE_ATTENTION=1
```

## Accuracy Tests

N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [x] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues

| Level | File | Summary |
|---|---|---|
| 📝 Convention | - | Modifications / Usage or Command / Accuracy Tests sections are empty; no checklist item is checked |
| ❓ Question | custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh:22 | File is in the repo but not referenced by any kernel (the matching include in the c8 impl is commented out, and the c16 impl does not include it); it contains SM90+-only TMA APIs |
| ❓ Question | custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh:15 | `// #include "cu_tensor_map.cuh"` is a commented-out line; it should be handled explicitly |
| 🟡 Suggestion | fastdeploy/worker/gpu_model_runner.py | gpu_model_runner.py and metax_model_runner.py are updated, but dcu_model_runner.py and iluvatar_model_runner.py are not in the change list; confirm whether other GPU-adjacent hardware needs this feature |

Overall Assessment

The PR is well structured overall: the C16/C8 decode attention kernel implementations are complete and come with test files. Main open points: whether cu_tensor_map.cuh is a leftover file (the TMA path is not wired in yet), and whether the runners for other GPU-adjacent hardware need the same update.
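The enablement condition this review describes (flash_attn on plus USE_DECODE_ATTENTION=1, with the C16 or C8 kernel chosen by KV-cache dtype) can be sketched as follows. All function names here are illustrative, not FastDeploy's actual backend API; only the environment-variable name and kernel file names come from the PR:

```python
import os

def use_decode_attention(flash_attn_enabled: bool) -> bool:
    """Decode attention is taken only when flash_attn is on AND the env var is set."""
    return flash_attn_enabled and os.getenv("USE_DECODE_ATTENTION", "0") == "1"

def pick_decode_kernel(flash_attn_enabled: bool, cache_dtype: str) -> str:
    """Illustrative dispatch: C16 for fp16/bf16 KV cache, C8 for static int8."""
    if not use_decode_attention(flash_attn_enabled):
        return "append_attention"  # existing fallback path
    if cache_dtype in ("float16", "bfloat16"):
        return "decode_append_attention_c16"
    if cache_dtype == "int8":
        return "decode_append_attention_c8"
    return "append_attention"

os.environ["USE_DECODE_ATTENTION"] = "1"
print(pick_decode_kernel(True, "bfloat16"))  # decode_append_attention_c16
print(pick_decode_kernel(True, "int8"))      # decode_append_attention_c8
print(pick_decode_kernel(False, "int8"))     # append_attention (flash_attn off)
```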

#include <cuda/barrier>
#include <stdexcept>

using barrier = cuda::barrier<cuda::thread_scope_block>;

❓ Question: `cu_tensor_map.cuh` has been added to the repo but is not referenced by any kernel: the matching `#include "cu_tensor_map.cuh"` on line 15 of `decode_append_attention_c8_impl.cuh` is commented out, and `decode_append_attention_c16_impl.cuh` does not include this header either.

The `cuda::device::experimental` namespace (Hopper TMA API) and `CUtensorMapDataType` used in this file are SM90+-only capabilities. Please confirm:

  1. If the file is reserved for a later TMA optimization, add a comment such as TODO: SM90+
  2. If it is unused for now, remove it from this PR and merge it when it is actually needed, to avoid a future accidental include causing build failures.

// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "utils.cuh"

❓ Question: `// #include "cu_tensor_map.cuh"` is commented out. If the C8 kernel ultimately does not need TMA, delete this commented line to keep the code clean; if it will be used later, add an SM90 gate (`#if __CUDA_ARCH__ >= 900`).

@codecov-commenter

Codecov Report

❌ Patch coverage is 64.36782% with 31 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@66dea60). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...l_executor/layers/attention/append_attn_backend.py | 0.00% | 10 Missing and 1 partial ⚠️ |
| ...el_executor/layers/attention/flash_attn_backend.py | 46.66% | 6 Missing and 2 partials ⚠️ |
| fastdeploy/spec_decode/mtp.py | 42.85% | 6 Missing and 2 partials ⚠️ |
| ...cutor/layers/attention/ops/config_for_attention.py | 85.71% | 0 Missing and 1 partial ⚠️ |
| ...or/layers/attention/ops/decode_append_attention.py | 88.88% | 0 Missing and 1 partial ⚠️ |
| ...ers/attention/ops/decoder_write_cache_with_rope.py | 88.88% | 0 Missing and 1 partial ⚠️ |
| fastdeploy/worker/gpu_model_runner.py | 85.71% | 0 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7729   +/-   ##
==============================================
  Coverage               ?   71.82%           
==============================================
  Files                  ?      381           
  Lines                  ?    54014           
  Branches               ?     8444           
==============================================
  Hits                   ?    38795           
  Misses                 ?    12445           
  Partials               ?     2774           
| Flag | Coverage Δ |
|---|---|
| GPU | 71.82% <64.36%> (?) |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

