Skip to content

[XPU] Add swap_cache_layout op to support Mooncake KV cache for XPU.#7728

Open
Jiajun-Ji wants to merge 2 commits intoPaddlePaddle:developfrom
Jiajun-Ji:xpu-mooncake
Open

[XPU] Add swap_cache_layout op to support Mooncake KV cache for XPU.#7728
Jiajun-Ji wants to merge 2 commits intoPaddlePaddle:developfrom
Jiajun-Ji:xpu-mooncake

Conversation

@Jiajun-Ji
Copy link
Copy Markdown
Contributor

@Jiajun-Ji Jiajun-Ji commented May 7, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

💡 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191)

添加swap_cache_layout op以支持mooncake,在_run_read_storage调用交换 xpu kv cache到cpu pinned memory。
mooncake原生不支持XPU,在CPU内存空间下连接mooncake后端。

pip install mooncake-transfer-engine-non-cuda

Modifications

Usage or Command

# ======================== 启动 Mooncake Master ========================
echo "=== Starting Mooncake Master ==="
mkdir -p log_master
nohup mooncake_master \
    --port=${MASTER_PORT} \
    --enable_http_metadata_server=true \
    --http_metadata_server_host=0.0.0.0 \
    --http_metadata_server_port=${METADATA_PORT} \
    --metrics_port=${MASTER_METRICS_PORT} \
    > log_master/nohup 2>&1 &
echo "pid=$!, log -> log_master/nohup"

sleep 2

# ======================== 启动实例 0(XPU 0)========================
echo "=== Starting Instance 0 (XPU 0, port ${S0_PORT}) ==="
export XPU_VISIBLE_DEVICES="0"
export FD_LOG_DIR="log_s0"
mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/${model_name} \
    --port ${S0_PORT} \
    --metrics-port $((S0_PORT+1)) \
    --engine-worker-queue-port $((S0_PORT+2)) \
    --max-model-len 8192 \
    --tensor-parallel-size 1 \
    --max-num-seqs 64 \
    --load-choices default \
    --enable-prefix-caching \
    --kvcache-storage-backend mooncake \
    > ${FD_LOG_DIR}/nohup 2>&1 &
echo "pid=$!, log -> ${FD_LOG_DIR}/nohup"

# ======================== 启动实例 1(XPU 1)========================
echo "=== Starting Instance 1 (XPU 1, port ${S1_PORT}) ==="
export XPU_VISIBLE_DEVICES="1"
export FD_LOG_DIR="log_s1"
mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/${model_name} \
    --port ${S1_PORT} \
    --metrics-port $((S1_PORT+1)) \
    --engine-worker-queue-port $((S1_PORT+2)) \
    --max-model-len 8192 \
    --tensor-parallel-size 1 \
    --max-num-seqs 64 \
    --load-choices default \
    --enable-prefix-caching \
    --kvcache-storage-backend mooncake \
    > ${FD_LOG_DIR}/nohup 2>&1 &
echo "pid=$!, log -> ${FD_LOG_DIR}/nohup"

# ======================== 等待两个实例就绪 ========================
echo "=== Waiting for services to be ready ==="
wait_for_health "${S0_PORT},${S1_PORT}"
echo "All services are ready!"

# ======================== 发送测试请求 ========================
msg="深圳是中国经济实力最强的城市之一。近年来,深圳 GDP 持续稳步增长,2023 年突破 3.4 万亿元人民币,2024 年接近 3.7 万亿元,长期位居全国城市前列。深圳经济以第二产业和第三产业为主,高端制造业、电子信息产业和现代服务业发达,形成了以科技创新为核心的产业结构。依托华为、腾讯、大疆等龙头企业,深圳在数字经济、人工智能、新能源等领域具有显著优势。同时,深圳进出口总额常年位居全国城市第一,是中国对外开放和高质量发展的重要引擎。深圳2024年 GDP 是多少?"

echo ""
echo ">>> Request 1: Instance 0 首次请求(写入 Mooncake 缓存)"
curl -sS -X POST "http://127.0.0.1:${S0_PORT}/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "{
    \"messages\": [{\"role\": \"user\", \"content\": \"${msg}\"}],
    \"max_tokens\": 50,
    \"stream\": false,
    \"top_p\": 0
  }" | python3 -m json.tool 2>/dev/null

echo ""
echo ">>> Waiting 5s for Instance 0 to write back KV cache to Mooncake..."
sleep 5

echo ""
echo ">>> Request 2: Instance 1 相同前缀请求(应命中 Mooncake 缓存)"
curl -sS -X POST "http://127.0.0.1:${S1_PORT}/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d "{
    \"messages\": [{\"role\": \"user\", \"content\": \"${msg}\"}],
    \"max_tokens\": 50,
    \"stream\": false,
    \"top_p\": 0
  }" | python3 -m json.tool 2>/dev/null

Accuracy Tests

config
image
cache_manager.log
image

实例1写入mooncake
image

实例2接受mooncake
image

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings May 7, 2026 03:03
@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented May 7, 2026

Thanks for your contribution!

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

该 PR 在 XPU 路径新增 swap_cache_layout 自定义算子,用于在 XPU KV cache(按 layer 存放)与 CPU pinned buffer(按 block-major、layer-minor 存放)之间进行布局转换与拷贝,从而支持 Mooncake 作为 KV cache storage backend 的 XPU 场景。

Changes:

  • Mooncake 配置在 CUDA/XPU 平台下默认自动探测并填充 RDMA NICs。
  • cache_manager 的 XPU ops 导出 swap_cache_layout,供存储读写路径使用。
  • 新增 XPU 自定义算子实现 swap_cache_layout 及其对应测试用例。

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py XPU 场景下也支持自动选择 RDMA 设备配置
fastdeploy/cache_manager/ops.py XPU 平台导入并暴露 swap_cache_layout
custom_ops/xpu_ops/src/ops/swap_cache_layout.cc 新增 XPU swap_cache_layout 算子实现(XPU↔CPU pinned buffer 布局转换拷贝)
custom_ops/xpu_ops/test/test_swap_cache_layout.py 新增 swap_cache_layout 的 roundtrip/性能测试用例

Comment on lines +57 to +61
auto* cache_cpu_ptr = reinterpret_cast<T*>(cache_cpu_pointer);

for (int block_idx = 0; block_idx < static_cast<int>(xpu_block_ids.size());
block_idx++) {
auto cur_xpu_block_id = xpu_block_ids[block_idx];
for (int i = 1; i < static_cast<int>(cache_shape.size()); i++) {
cache_block_stride *= cache_shape[i];
}

Comment on lines +121 to +131
mode,
)
paddle.device.synchronize()
cost_time = time.time() - start
print(
f"swap cache layout ({label}), total_gb: {total_gb:.6f}GB, "
f"cost_time: {cost_time:.6f}s, speed: {total_gb / cost_time:.6f}GB/s"
)

def test_performance(self):
for _ in range(3):
@PaddlePaddle-bot
Copy link
Copy Markdown

PaddlePaddle-bot commented May 7, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-07 11:50:24

CI报告基于以下代码生成(30分钟更新一次):


1 任务总览

⚠️ Required 失败任务数:1,等待处理的 Required 任务数:1(运行中)

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
36(0) 36 31 3 1 1 0

2 任务状态汇总

2.1 Required任务 : 8/10 通过

必选任务阻塞合并,失败需优先处理。

状态 任务 耗时 根因 修复建议 日志 重跑
Approval 7s PR问题:缺少 custom op 所需 RD 审批(2项未满足) 请 FastDeploy RD + PaddlePaddle RD 各 1 人 Approve Job -
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage - 运行中 - Job -
其余 8 个必选任务通过 - - - - -

2.2 可选任务 — 23/26 通过

可选任务不阻塞合并,失败仅供参考。

状态 任务 耗时 日志 重跑
Run iluvatar Tests / run_iluvatar_cases 10m27s Job -
Check PR Template 16s Job -
⏸️ CI_HPU - - -
其余 23 个可选任务通过 - - -

3 失败详情(仅 required)

Approval — 代码规范(custom op 审批)(置信度: 高)

Approval

  • 状态: ❌ 失败
  • 错误类型: 代码规范(custom op 审批检查)
  • 置信度: 高
  • 根因摘要: PR问题:缺少 custom op 所需 RD 审批(2个审批未满足)
  • 分析器: 通用分析(fallback)

根因详情:
PR 新增 XPU custom op(swap_cache_layout),触发了 scripts/check_approval.sh 审批检查脚本。该脚本要求:(1) FastDeploy RD 中至少 1 人 Approve(qingqing01/Jiang-Jia-Jun/heavengate);(2) PaddlePaddle RD 中至少 1 人 Approve(jeff41404/yongqiangma)。目前两项均未满足,脚本以 exit code 6 退出。

关键日志:

0. You must have one FastDeploy RD (qingqing01(dangqingqing), Jiang-Jia-Jun(jiangjiajun), heavengate(dengkaipeng)) approval for adding custom op.
1. You must have one PaddlePaddle RD (jeff41404(gaoxiang), yongqiangma(mayongqiang)) approval for adding custom op.
There are 2 approved errors.
##[error]Process completed with exit code 6.

修复建议:

  1. 请 FastDeploy RD(qingqing01、Jiang-Jia-Jun 或 heavengate)在 PR 页面上点击 Approve
  2. 请 PaddlePaddle RD(jeff41404 或 yongqiangma)在 PR 页面上点击 Approve

修复建议摘要: 请 FastDeploy RD + PaddlePaddle RD 各 1 人在 PR 上 Approve

关联变更: PR 标题 [XPU] Add swap_cache_layout op 即为新增 custom op 的变更,触发了审批要求。

链接: 查看日志

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-07 11:13:41

📋 Review 摘要

PR 概述:为 XPU 平台新增 swap_cache_layout custom op,实现 XPU KV cache 与 CPU pinned memory 的 layout 转换,并将 mooncake_store 的 RDMA NIC 自动探测扩展到 XPU 平台

变更范围custom_ops/xpu_ops/src/ops/fastdeploy/cache_manager/

影响面 Tag[XPU] [KVCache] [OP]

📝 PR 规范检查

## Modifications 段落内容为空(仅保留模板占位注释),Checklist 条目均未按实际情况勾选。

标题建议(可直接复制):

  • [XPU][KVCache] Add swap_cache_layout op to support Mooncake KV cache for XPU

PR 描述建议(可直接复制,必须复刻 checklist §D2 模板的完整结构):

## Motivation
添加 swap_cache_layout op 以支持 mooncake,在 _run_read_storage 中调用以在 XPU KV cache 与 CPU pinned memory 之间执行 layout swap。mooncake 原生不支持 XPU,在 CPU 内存空间下连接 mooncake 后端。

## Modifications
- 新增 `custom_ops/xpu_ops/src/ops/swap_cache_layout.cc`:实现 XPU KV cache(layout: `[block_num, head_num, block_size, head_dim]`)与 CPU pinned memory(layout: `[block_num, layer_num, head_num, block_size, head_dim]`)之间的数据搬移 op,支持 mode=0(XPU→CPU)和 mode=1(CPU→XPU)两个方向
- 新增 `custom_ops/xpu_ops/test/test_swap_cache_layout.py`:涵盖 roundtrip 正确性验证和 XPU↔CPU 带宽性能测试
- `fastdeploy/cache_manager/ops.py`:修复 XPU 平台下 `swap_cache_layout` 被错误置为 `None` 的问题,改为从 `fastdeploy.model_executor.ops.xpu` 正确导入
- `fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py`:将 RDMA NIC 自动探测逻辑(`get_rdma_nics()`)扩展到 XPU 平台

## Usage or Command
参见 PR 描述中的 shell 启动脚本(启动 Mooncake Master + 双 XPU 实例 + 验证请求)。

## Accuracy Tests
实例 1 写入 mooncake,实例 2 命中 mooncake 缓存(截图已在 PR 描述中提供)。

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [x] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

问题

级别 文件 概述
🟡 建议 custom_ops/xpu_ops/src/ops/swap_cache_layout.cc:74 xpu_memcpy 在 layer×block 双重循环中逐次同步调用,大模型场景下串行 XDMA 调用较多

总体评价

实现思路清晰,修复了 XPU 平台下 swap_cache_layout 被错误置为 None 的遗留问题,功能完整、roundtrip 测试和性能测试覆盖到位。仅 ## Modifications 段落为空,建议补全以便追溯。

void* src = (mode == 0) ? static_cast<void*>(xpu_ptr_now)
: static_cast<void*>(cpu_ptr_now);

int ret = xpu_memcpy(dst, src, cache_block_stride * sizeof(T), copy_kind);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 xpu_memcpylayer_num × block_num 双重循环中逐次同步调用。

对于大模型(32+ 层、多 block 场景),会产生大量串行 XDMA 调用,可能成为吞吐瓶颈。建议评估 XPU runtime 是否支持流式/异步 memcpy 批量提交,或在同一层内批量提交多个 block 的传输请求以提升并发度。

@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@45350ff). Learn more about missing BASE report.

Files with missing lines Patch % Lines
.../transfer_factory/mooncake_store/mooncake_store.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7728   +/-   ##
==========================================
  Coverage           ?   71.60%           
==========================================
  Files              ?      396           
  Lines              ?    55568           
  Branches           ?     8688           
==========================================
  Hits               ?    39791           
  Misses             ?    13039           
  Partials           ?     2738           
Flag Coverage Δ
GPU 71.60% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants