[Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models#7727
chang-wenbin wants to merge 9 commits into PaddlePaddle:develop
Conversation
chang-wenbin seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it.
Thanks for your contribution!
CI report generated from the code below (refreshed every 30 minutes):
1 Task overview: 1 Required task is currently failing.
2 Task status summary
2.1 Required tasks: 1/4 passed
2.2 Optional tasks: 14/19 passed
3 Failure details (Required only): Approval — code style (confidence: high)
Suggested fix: request approval of this PR from xyxinyang or zyyzghb. Link: view logs
Codecov Report: ❌ Patch coverage is
Additional details and impacted files:
```
@@           Coverage Diff           @@
##           develop    #7727  +/-  ##
=========================================
  Coverage         ?   71.57%
=========================================
  Files            ?      396
  Lines            ?    55661
  Branches         ?     8698
=========================================
  Hits             ?    39837
  Misses           ?    13086
  Partials         ?     2738
```
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Pull request overview
This PR aims to keep DeepSeek-V3's MLA attention correct when prefix cache and chunked prefill are combined: it fixes the starting position / offset computation for position_ids, includes the cached KV in cu_seqlens_k/max_seqlen_k for the prefill/mixed FlashAttention calls, and adds the path that reads cached latents from the paged latent cache and interleaves them with the new-token latents.
Changes:
- Fix the position_ids error in get_position_ids_and_mask_encoder_batch caused by double-counted offsets under chunked prefill, and write position_ids starting from cached_len.
- Add a fused read-cache + interleave path (naive/Triton) to MLAAttentionBackend and extend MLAAttentionMetadata; prefill/mixed now call FlashAttention with the cache-aware cu_seqlens_k_with_cache and max_total_kv_len (see the toy seqlen illustration below).
- In the DeepSeek-V3 prefill branch, read the cached latents, interleave them with the new-token latents before the KV projection, and adjust the key tensor shape to cover all KV tokens.
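To make the seqlen bookkeeping concrete, here is a toy, hedged illustration: only the names `cu_seqlens_q`, `cu_seqlens_k_with_cache`, and `max_total_kv_len` come from this PR; the request lengths and the rest of the wiring are invented for the example.

```python
import paddle

# Invented example: 3 requests, some of which already have tokens in the prefix cache.
seq_lens_cached = paddle.to_tensor([16, 0, 48], dtype="int32")  # tokens already cached per request
seq_lens_new = paddle.to_tensor([8, 32, 8], dtype="int32")      # new tokens in this chunk per request

zero = paddle.zeros([1], dtype="int32")
# Q only covers the new chunk; K/V must cover cached + new tokens.
cu_seqlens_q = paddle.concat([zero, paddle.cumsum(seq_lens_new, dtype="int32")])
cu_seqlens_k_with_cache = paddle.concat(
    [zero, paddle.cumsum(seq_lens_cached + seq_lens_new, dtype="int32")]
)
max_total_kv_len = int(paddle.max(seq_lens_cached + seq_lens_new).item())  # 56 here
print(cu_seqlens_q.numpy(), cu_seqlens_k_with_cache.numpy(), max_total_kv_len)
```

If FlashAttention were given only the new-token lengths for K, the attention tiles would stop before the cached prefix, which matches the silent-truncation symptom described in this PR.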
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| fastdeploy/model_executor/models/deepseek_v3.py | Wire the prefill branch to the prefix cache: read cached latents from the paged latent cache and interleave them with the new-token latents before the KV projection; adjust the key shape/assignment logic |
| fastdeploy/model_executor/layers/attention/mla_attention_backend.py | Add fused read-cache + interleave (naive/Triton) and prefix-cache metadata; prefill/mixed FlashAttention now uses cache-aware seqlens/maxlen; replace print with logger |
| custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu | Fix the offset and starting-position computation of position_ids under chunked prefill + prefix cache |
```python
if need_do_prefill:  # max_enc_len_this_time
    key_value = self.kv_b_proj(compressed_kv)
    # Check for prefix cache
    attn_meta = forward_meta.attn_backend.attention_metadata if hasattr(forward_meta, "attn_backend") else None
    has_prefix_cache = False
    total_cached_tokens = 0

    if attn_meta is not None and isinstance(attn_meta, MLAAttentionMetadata):
        has_prefix_cache = attn_meta.has_prefix_cache
        total_cached_tokens = attn_meta.total_cached_kv_tokens
    if has_prefix_cache and total_cached_tokens > 0:
        layer_id = self.mla_attn.layer_id if hasattr(self.mla_attn, "layer_id") else 0
        latent_cache = forward_meta.caches[layer_id] if hasattr(forward_meta, "caches") else None
        if latent_cache is not None:
            block_size = self.mla_attn.block_size if hasattr(self.mla_attn, "block_size") else 64
            full_compressed_kv, full_k_pe = fused_read_cache_and_interleave(
                latent_cache,
                forward_meta.block_tables,
                compressed_kv,
                key_pe.squeeze(1),
                attn_meta.cu_seqlens_cached_kv,
                forward_meta.cu_seqlens_q,
                self.kv_lora_rank,
                self.qk_rope_head_dim,
                block_size,
            )
```
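For orientation, a minimal sketch of how a cached latent row could be split back into its two parts, assuming each cached token packs compressed_kv followed by k_pe (the packing order and the sizes below are assumptions, not confirmed by this diff):

```python
import paddle

# Hypothetical packing: [num_cached_tokens, kv_lora_rank + qk_rope_head_dim]
kv_lora_rank, qk_rope_head_dim = 512, 64          # example values only
cached_latent = paddle.randn([10, kv_lora_rank + qk_rope_head_dim])

cached_compressed_kv = cached_latent[:, :kv_lora_rank]                          # [10, 512]
cached_k_pe = cached_latent[:, kv_lora_rank:kv_lora_rank + qk_rope_head_dim]    # [10, 64]
```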
```python
    bsz = cu_seqlens_cached_kv.shape[0] - 1
    cu_cached = cu_seqlens_cached_kv.tolist()
    cu_new = cu_seqlens_q.tolist()
    total_cached = int(cu_cached[bsz])
    total_new = new_compressed_kv.shape[0]
    total_tokens = total_cached + total_new

    full_compressed_kv = paddle.empty([total_tokens, kv_lora_rank], dtype=new_compressed_kv.dtype)
    full_k_pe = paddle.empty([total_tokens, qk_rope_head_dim], dtype=new_k_pe.dtype)
    if total_tokens == 0:
        return full_compressed_kv, full_k_pe

    # block_tables.tolist() is a one-shot D2H; acceptable since host-side loop
    # already requires CPU iteration over total_tokens.
    bt_list = block_tables.tolist()

    is_cached_host = [0] * total_tokens
    src_off_host = [0] * total_tokens
    out_pos = 0
    for b in range(bsz):
        nc = int(cu_cached[b + 1]) - int(cu_cached[b])
```
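Once the per-output-token flags and source offsets are built on the host, the interleave itself can be expressed as a pair of gathers. The toy below is only a sketch of that idea, not the PR's actual gather step; all names, shapes, and values are invented, and in the real path the cached rows come from the paged cache via block tables rather than a contiguous tensor.

```python
import paddle

# Toy layout: requests of (2 cached, 1 new), (1 cached, 1 new), (0 cached, 1 new) tokens.
is_cached = paddle.to_tensor([1, 1, 0, 1, 0, 0], dtype="bool")  # cached token vs new token
src_off = paddle.to_tensor([0, 1, 0, 2, 1, 2], dtype="int64")   # row index in the source tensor

cached_rows = paddle.randn([3, 4])  # stand-in for latents read from the paged cache
new_rows = paddle.randn([3, 4])     # stand-in for this chunk's new latents

# Gather from both sources, then select per output position.
out = paddle.where(
    is_cached.unsqueeze(-1),
    paddle.index_select(cached_rows, src_off, axis=0),
    paddle.index_select(new_rows, src_off, axis=0),
)
print(out.shape)  # [6, 4] — cached and new rows interleaved per request
```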
```python
def fused_read_cache_and_interleave(*args, **kwargs):
    """Unified entry. ``FD_MLA_USE_NAIVE=1`` forces the Python reference path."""
    if os.environ.get("FD_MLA_USE_NAIVE", "0") == "1":
```
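A plausible shape for the rest of this dispatcher, assuming it simply forwards to the naive or Triton variant named in this PR; the Triton-unavailable fallback behaviour is an assumption, not taken from the diff:

```python
import os

def fused_read_cache_and_interleave(*args, **kwargs):
    """Sketch of the unified entry: FD_MLA_USE_NAIVE=1 forces the Python reference path."""
    if os.environ.get("FD_MLA_USE_NAIVE", "0") == "1":
        return fused_read_cache_and_interleave_naive(*args, **kwargs)
    try:
        return fused_read_cache_and_interleave_triton(*args, **kwargs)
    except Exception:
        # Hypothetical fallback: if the Triton kernel cannot run, fall back to the
        # Python reference path instead of failing the request.
        return fused_read_cache_and_interleave_naive(*args, **kwargs)
```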
```cpp
// Compute this batch's offset dynamically.
// Each batch contributes either encoder_len or seq_lens_this_time, not their sum
// (under chunked prefill a batch can have encoder_len > 0 and decoder_len > 0 at the
// same time, but it still only holds encoder_len real tokens).
int offset = 0;
for (int i = 0; i < tid; i++) {
  if (seq_lens_encoder[i] > 0) {
    offset += seq_lens_encoder[i];
  } else if (seq_lens_decoder[i] > 0) {
    offset += seq_lens_this_time[i];
  }
}
```
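A host-side toy check of the same offset rule, with invented per-batch lengths; Python is used only to illustrate the rule, the kernel above is the actual fix:

```python
# Invented example: batch 0 is a chunked-prefill request (encoder and decoder lens both > 0),
# batch 1 is a pure decode request, batch 2 is a plain prefill request.
seq_lens_encoder = [128, 0, 64]
seq_lens_decoder = [256, 300, 0]
seq_lens_this_time = [128, 1, 64]

def batch_offset(tid):
    """Offset of batch `tid` in the flattened token stream: each earlier batch
    contributes either its encoder_len or its seq_lens_this_time, never their sum."""
    offset = 0
    for i in range(tid):
        if seq_lens_encoder[i] > 0:
            offset += seq_lens_encoder[i]
        elif seq_lens_decoder[i] > 0:
            offset += seq_lens_this_time[i]
    return offset

print([batch_offset(t) for t in range(3)])  # [0, 128, 129]
```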
| """MLA attention forward with prefix cache support.""" | ||
|
|
||
| from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( | ||
| MLAAttentionMetadata, | ||
| fused_read_cache_and_interleave, | ||
| ) |
```python
cu_total = [0] * (bsz + 1)
cumsum_cached = 0
cumsum_total = 0
for i in range(bsz):
```
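A plausible, self-contained completion of this prefix-sum bookkeeping; `seq_lens_cached`, `seq_lens_new`, and `cu_cached` are assumed names and the lengths are invented, so this is a sketch rather than the PR's actual loop body:

```python
seq_lens_cached = [16, 0, 48]  # assumed name: cached tokens per request
seq_lens_new = [8, 32, 8]      # assumed name: new tokens per request
bsz = len(seq_lens_new)

cu_cached = [0] * (bsz + 1)
cu_total = [0] * (bsz + 1)
cumsum_cached = 0
cumsum_total = 0
for i in range(bsz):
    cumsum_cached += seq_lens_cached[i]
    cumsum_total += seq_lens_cached[i] + seq_lens_new[i]
    cu_cached[i + 1] = cumsum_cached
    cu_total[i + 1] = cumsum_total

print(cu_cached)  # [0, 16, 16, 64]
print(cu_total)   # [0, 24, 56, 112]
```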
Could this be handled by a custom op instead, so the D2H and H2D copies are avoided and the host-side work is simplified?
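A hedged sketch of that suggestion: keep the interleave bookkeeping on device so no `.tolist()` round-trips are needed. All names, shapes, and lengths below are invented; this only illustrates the device-side indexing idea, not an actual FastDeploy op.

```python
import paddle

seq_lens_cached = paddle.to_tensor([2, 0, 3], dtype="int64")
seq_lens_new = paddle.to_tensor([1, 2, 1], dtype="int64")
seq_lens_total = seq_lens_cached + seq_lens_new

# For every output token, mark whether it comes from the cache: within each request the
# first `cached` positions are cached, the remaining positions are new tokens.
total = int(seq_lens_total.sum().item())
req_id = paddle.repeat_interleave(paddle.arange(seq_lens_total.shape[0]), seq_lens_total)
req_starts = paddle.concat([paddle.zeros([1], dtype="int64"), paddle.cumsum(seq_lens_total)])[:-1]
pos_in_req = paddle.arange(total) - paddle.index_select(req_starts, req_id, axis=0)
is_cached = pos_in_req < paddle.index_select(seq_lens_cached, req_id, axis=0)
print(is_cached.numpy())  # [True, True, False, False, False, True, True, True, False]
```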
PaddlePaddle-bot left a comment
🤖 Paddle-CI-Agent | pr_review | 2026-05-09 11:37:22
📋 Review Summary
PR overview: adds chunked prefill + prefix cache support for MLA-architecture models, fixes the position_ids computation error and the FlashAttention seqlen truncation, and adds a Triton fused read+interleave kernel.
Scope of changes: custom_ops/gpu_ops/, fastdeploy/model_executor/layers/attention/, fastdeploy/model_executor/models/, fastdeploy/model_executor/pre_and_post_process.py
Impact tags: [Models] [OP] [KVCache]
📝 PR Convention Check
The Usage or Command and Accuracy Tests sections are empty (placeholder comments only), and some checklist items are left unchecked.
Suggested title (copy-paste ready):
[Feature] Support MLA chunk-prefill & prefix cache for all MLA Architecture Models
(The title format is compliant; no change needed.)
Suggested PR description (copy-paste ready; it must follow the full structure of the checklist §D2 template):
## Motivation
The existing MLA (Multi-head Latent Attention) implementation does not support combining prefix cache with chunked prefill:
1. In `get_position_ids_and_mask_encoder_batch.cu`, the offset computation adds encoder_len and decoder_len together under chunked prefill, producing wrong position_ids.
2. The FlashAttention calls in `forward_extend` / `forward_mixed` do not include cached KV tokens in `cu_seqlens_k` and `max_seqlen_k`, so attention tiles are truncated and the output is silently corrupted.
3. There is no mechanism for reading cached KV from the paged latent cache and interleaving it with the new tokens' KV.
## Modifications
- `custom_ops/gpu_ops/get_padding_offset.cu`: add an optional `seq_lens_decoder` argument and compute `cu_seqlens_q` (new tokens only) and `cu_seqlens_k` (cached + new) separately for FlashAttention.
- `custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu`: fix the offset accumulation under chunked prefill so the cached prefix length is correctly used as the position starting point.
- `fastdeploy/model_executor/layers/attention/mla_attention_backend.py`:
  - Add `fused_read_cache_and_interleave_naive` (Python reference implementation) and `fused_read_cache_and_interleave_triton` (Triton-accelerated version), unified behind the `fused_read_cache_and_interleave` entry point (switch via the environment variable `FD_MLA_USE_NAIVE=1`).
  - Add a `max_seqlen_k` field to `MLAAttentionMetadata`.
  - Use `max_seqlen_k` instead of `max_enc_len_this_time` in the FlashAttention calls in `forward_extend` / `forward_mixed`.
- `fastdeploy/model_executor/models/deepseek_v3.py`: in the prefill branch, read the cached latents and interleave them with the new-token latents before the KV projection; change the key tensor shape to `[full_tokens, heads, qk_head_dim]`.
## Usage or Command
N/A
## Accuracy Tests
N/A
## Checklist
- [x] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues
| Level | File | Summary |
|---|---|---|
| 🟡 Suggestion | fastdeploy/model_executor/layers/attention/mla_attention_backend.py:129 | `assert` is used for runtime validation and is silently skipped under `python -O` |
| ❓ Question | fastdeploy/model_executor/layers/attention/mla_attention_backend.py:587 | the semantics of `max_kv_len_this_time` used in the `max_seqlen_k` computation need confirmation |
Overall assessment
The overall direction of the fix is correct: all three root causes are accurately identified, the Triton kernel design is complete, and the test coverage is thorough. The author should confirm that `max_len_tensor_cpu[5]` already includes the cached length in the prefill+cache scenario, to rule out the risk of underestimating `max_seqlen_k`.
```python
            full_k_pe[out_pos] = new_k_pe[new_base + t]
            out_pos += 1

    assert (
```
🟡 Suggestion: `assert` is used here for runtime correctness validation, but it is silently skipped under `python -O`.
Suggest raising an explicit exception instead:
```python
if out_pos != total_tokens:
    raise RuntimeError(
        f"fused_read_cache_and_interleave_naive: out_pos={out_pos} != total_tokens={total_tokens}"
    )
```

```python
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5]
metadata.max_seqlen_k = max(metadata.max_kv_len_this_time.item(), metadata.max_enc_len_this_time.item())
```
❓ Question: `max_seqlen_k` uses `max(max_kv_len_this_time, max_enc_len_this_time)`. Please confirm that `max_len_tensor_cpu[5]` (`max_kv_len_this_time`) already includes cached_len + new_len for prefill batches that have a prefix cache.
If `max_kv_len_this_time` only reflects the KV length of decode batches, `max_seqlen_k` may be underestimated under chunked prefill + prefix cache, and the FA tiles would still be truncated.
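One defensive option, sketched with assumed names (`seq_lens_cached` and `seq_lens_this_time` are not taken from the diff), is to derive `max_seqlen_k` directly from the per-request cached + new lengths rather than relying on the semantics of `max_len_tensor_cpu[5]`:

```python
import paddle

# Invented per-request lengths, for illustration only.
seq_lens_cached = paddle.to_tensor([100, 0, 40], dtype="int32")    # cached prefix per request
seq_lens_this_time = paddle.to_tensor([28, 1, 24], dtype="int32")  # tokens processed this step

# Hypothetical defensive computation: max over cached + new, so cached prefixes are
# always counted even if max_len_tensor_cpu[5] turns out to exclude them.
max_seqlen_k = int(paddle.max(seq_lens_cached + seq_lens_this_time).item())
print(max_seqlen_k)  # 128
```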
## Motivation
The existing MLA (Multi-head Latent Attention) implementation does not support combining prefix cache with chunked prefill:
- In `get_position_ids_and_mask_encoder_batch.cu`, the offset computation adds encoder_len + decoder_len together under chunked prefill, producing wrong position_ids.
- The FlashAttention calls in `forward_extend` / `forward_mixed` do not include cached KV tokens in `cu_seqlens_k` and `max_seqlen_k`, so attention tiles are truncated and the output is silently corrupted.

## Modifications
- `custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu`: fix the offset accumulation under chunked prefill so the cached prefix length is correctly used as the position starting point.
- `fastdeploy/model_executor/layers/attention/mla_attention_backend.py`:
  - Add `fused_read_cache_and_interleave_naive` (Python reference implementation) and `fused_read_cache_and_interleave_triton` (Triton-accelerated version), unified behind the `fused_read_cache_and_interleave` entry point (switch via the environment variable `FD_MLA_USE_NAIVE=1`).
  - Add prefix-cache fields to `MLAAttentionMetadata` (`has_prefix_cache`, `cu_seqlens_cached_kv`, `cu_seqlens_k_with_cache`, `max_total_kv_len`, etc.).
  - Add prefix-cache metadata computation in `init_attention_metadata`.
  - Use the cache-aware `cu_seqlens_k_with_cache` and `max_total_kv_len` in the FlashAttention calls in `forward_extend` / `forward_mixed`.
- `fastdeploy/model_executor/models/deepseek_v3.py`: in the prefill branch, read the cached latents and interleave them with the new-token latents before the KV projection; change the key tensor shape to `[full_tokens, heads, qk_head_dim]`.

## Usage or Command

## Accuracy Tests

## Checklist
- Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`, `[APIServer]`, `[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- Format your code, run `pre-commit` before commit.
- If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.