[RFC]: Separated CPU KV Cache Offloading/Transfer Process #22605

@ApostaC

Motivation.

CPU KV cache offloading and CPU-based KV cache transmission have long been important to vLLM.
There are already some existing solutions, such as KV-connector-based offloading and KV-connector-based KV transmission.
Quite a few RFCs are also proposing new solutions for CPU KV cache offloading (#16144, #19854).

However, the current implementations and proposals have several inefficiencies and complexity issues:

  1. CPU overhead in workers – memcpy kernel launches and other CPU KV tasks run in worker processes, impacting inference performance.
  2. Extra cross-worker IPC – CPU KV metadata gathering requires additional IPC, adding latency and complexity.
  3. Coupling with inference loop – Transmission logic is only triggered when a worker is invoked, adding unnecessary latency.
  4. Fragmented code paths – CPU offloading, eviction, and CPU-based prefill/disaggregated paths are handled separately, resulting in duplicated h2d/d2h calls.
  5. Scheduler complexity – The scheduler has grown complicated due to various asynchronous handling requirements.

Also, to land such functionality in production, there are a few more points to consider:

  1. Fault tolerance -- errors in the KV cache offloading/transmission module should not propagate to the main vLLM process, nor impact output correctness.
  2. Full CUDA-graph compatibility -- the offloading operations should be compatible with CUDA graph operations.
  3. Compatibility with the hybrid memory allocator.

Proposed Change.

Instead of embedding the offloading function calls into the worker process, this RFC proposes a new design that separates the KV cache offloader/sender/receiver into a dedicated process, as shown in the diagram below.

[Diagram: proposed architecture with a separate CPU KV cache offloader/sender/receiver process alongside the scheduler and worker processes]

The key techniques behind this proposal are PyTorch's UntypedStorage (for raw tensor data access) and CUDA IPC handles (for inter-process data sharing).
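
To make the mechanism concrete, below is a minimal sketch (not the actual vLLM code; shapes and variable names are illustrative) of sharing a worker's GPU KV cache tensor with a separate process. It uses reduce_tensor, the internal helper that torch.multiprocessing uses to produce CUDA IPC handles, together with untyped_storage() for raw access to the mapped bytes.

```python
import torch
from torch.multiprocessing.reductions import reduce_tensor

# --- in the worker process ---
# Illustrative KV cache shape: (num_blocks, block_size, num_kv_heads, head_dim)
kv_cache = torch.empty((1024, 16, 8, 128), dtype=torch.float16, device="cuda:0")

# reduce_tensor() packages a CUDA IPC handle plus storage offset/size so the
# tensor can be rebuilt in another process without copying any data.
rebuild_fn, rebuild_args = reduce_tensor(kv_cache)
# (rebuild_fn, rebuild_args) is picklable and can be shipped to the CPU KV
# process (e.g., over zmq) during the handshake.

# --- in the CPU KV process (a different OS process) ---
# Re-opening the handle maps the same GPU memory; reads and writes here operate
# directly on the worker's KV cache.
shared_kv_cache = rebuild_fn(*rebuild_args)
raw_bytes = shared_kv_cache.untyped_storage()  # raw, typeless view of the data
```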

Workflow in the new proposal

During bootstrap, the scheduler sends a handshake message containing the vLLM config to the "CPU KV process", and the workers send handshake messages to register their GPU KV cache tensors with the CPU KV process.
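
A minimal sketch of this handshake, assuming a zmq REQ/REP pattern; the endpoint, message fields, and function names below are hypothetical and only illustrate the registration flow described above:

```python
import pickle
import zmq

CPU_KV_ENDPOINT = "ipc:///tmp/cpu_kv_process"  # assumed endpoint

def scheduler_handshake(vllm_config) -> None:
    # Scheduler registers the vLLM config with the CPU KV process.
    sock = zmq.Context.instance().socket(zmq.REQ)
    sock.connect(CPU_KV_ENDPOINT)
    sock.send(pickle.dumps({"type": "scheduler_handshake",
                            "vllm_config": vllm_config}))
    assert pickle.loads(sock.recv())["status"] == "ok"

def worker_handshake(rank: int, kv_cache_ipc_handles) -> None:
    # Each worker registers CUDA IPC handles for its GPU KV cache tensors
    # (produced as in the previous sketch).
    sock = zmq.Context.instance().socket(zmq.REQ)
    sock.connect(CPU_KV_ENDPOINT)
    sock.send(pickle.dumps({"type": "worker_handshake",
                            "rank": rank,
                            "kv_cache_handles": kv_cache_ipc_handles}))
    assert pickle.loads(sock.recv())["status"] == "ok"
```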

At runtime, the scheduler sends the requests/pages whose KV cache needs to be offloaded/onloaded/sent to the CPU KV process via zmq. The CPU KV process sends the "finished request IDs" back to the scheduler once the corresponding offload/onload/send/receive operations complete.
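
A minimal sketch of one runtime exchange, again with a hypothetical message schema and a hypothetical kv_manager object inside the CPU KV process:

```python
import pickle
import zmq

def cpu_kv_step(sock: zmq.Socket, kv_manager) -> None:
    # The scheduler pushes the commands for one scheduling step; the reply
    # carries the request IDs whose KV operations finished since last time.
    commands = pickle.loads(sock.recv())
    for cmd in commands:
        if cmd["op"] == "offload":            # GPU -> CPU KV buffer
            kv_manager.offload(cmd["request_id"], cmd["pages"])
        elif cmd["op"] == "onload":           # CPU KV buffer -> GPU
            kv_manager.onload(cmd["request_id"], cmd["pages"])
        elif cmd["op"] in ("send", "recv"):   # cross-instance KV transfer
            kv_manager.transfer(cmd["request_id"], cmd["pages"], cmd.get("peer"))
    # Copies run asynchronously on a side CUDA stream; only requests whose
    # copies have completed are reported back to the scheduler.
    sock.send(pickle.dumps({"finished_request_ids":
                            kv_manager.pop_finished_request_ids()}))
```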

The key difference from all the existing solutions is that the worker no longer needs to launch any memory-related kernels. Therefore, the overhead inside the worker process is minimized, and GPU utilization is maximized.
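
As a sketch of where the memcpy kernel launches now live (assuming the IPC-mapped shared_kv_cache from the earlier sketch and an illustrative pinned CPU buffer), the CPU KV process can issue the D2H copies on its own CUDA stream and poll a CUDA event for completion, leaving the worker's stream untouched:

```python
import torch

copy_stream = torch.cuda.Stream()  # side stream owned by the CPU KV process

def offload_blocks(shared_kv_cache: torch.Tensor,
                   cpu_kv_buffer: torch.Tensor,
                   gpu_block_ids: list[int],
                   cpu_block_ids: list[int]) -> torch.cuda.Event:
    """Asynchronously copy the selected GPU KV blocks into the CPU buffer."""
    with torch.cuda.stream(copy_stream):
        for gpu_id, cpu_id in zip(gpu_block_ids, cpu_block_ids):
            # Pinned destination + non_blocking=True -> a true cudaMemcpyAsync.
            cpu_kv_buffer[cpu_id].copy_(shared_kv_cache[gpu_id], non_blocking=True)
    done = torch.cuda.Event()
    done.record(copy_stream)   # poll done.query() to report "finished" later
    return done

# Example setup of the CPU-side buffer (shape mirrors the GPU cache):
# cpu_kv_buffer = torch.empty(shared_kv_cache.shape, dtype=shared_kv_cache.dtype,
#                             pin_memory=True)
```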

Key benefits

  1. Minimized overhead: All memcpy kernel launches are moved out of the workers into the CPU KV manager process, minimizing worker overhead. A single CPU KV memory buffer is also shared across all workers, which eliminates cross-worker IPC for KV metadata.
  2. No extra copy: Unifies CPU KV transmission, CPU offloading, and eviction into one module.
  3. Cleaner code structure: Decouples CPU KV transmission from the LLM inference loop.
  4. Fault tolerance: Failures in the CPU KV cache won't impact the main vLLM processes.
  5. CUDA graph compatibility: The new process won't impact vLLM's CUDA graphs.

Other side benefits (with careful engineering)

  1. For MLA models, only a single copy of the KV cache needs to be saved when TP > 1.
  2. Simplified code paths for async offloading/eviction, which will clean up the scheduler code.

Feedback Period.

No response

CC List.

@simon-mo @WoosukKwon @robertgshaw2-redhat @njhill @KuntaiDu

Any Other Things.

Early promise

We use the following two workloads to test the worst-case overhead of offloading (i.e., 0% prefix cache hit):

  1. Llama 3.1 8B, random workload, 1000 input tokens, 100 output tokens, 1000 requests
  2. Llama 3.1 8B, random workload, 8000 input tokens, 100 output tokens, 200 requests

The new prototype implementation (without careful engineering) has already reduced the overhead from ~10% to ~3%.

On workload 1:

| Metric | Original vLLM | Existing connector | Proposed prototype |
|---|---|---|---|
| Mean TPOT (ms) | 162.11 | 181.27 | 169.16 |
| Token throughput (tokens/sec) | 33161.90 | 29690.18 | 32013.08 |

On workload 2:

| Metric | Original vLLM | Existing connector | Proposed prototype |
|---|---|---|---|
| Mean TPOT (ms) | 200.33 | 214.42 | 209.48 |
| Token throughput (tokens/sec) | 30846.83 | 28738.53 | 29833.48 |

Roadmap

Other potential discussions

  1. Layerwise pipelining
    There is a concern that if we separate the CPU KV cache management into another process, it will be hard to do layerwise pipelining. However, layerwise pipelining becomes less important once the process is separated, because (i) the overhead on the worker side is already minimized, and computation and communication naturally overlap in the new design, and (ii) for PD disaggregation, we can do chunk-by-chunk pipelining to achieve performance similar to layerwise pipelining, as sketched below.
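
A minimal sketch of chunk-by-chunk pipelining (illustrative only; send_fn is a hypothetical hook that ships a staged chunk to the peer, e.g. over zmq or RDMA):

```python
import torch

def send_kv_chunked(shared_kv_cache, page_ids, send_fn, chunk_size=64):
    """Overlap GPU->CPU copies of one chunk with the network send of earlier chunks."""
    stream = torch.cuda.Stream()
    pending = []  # (staged_host_chunk, copy_done_event)
    for start in range(0, len(page_ids), chunk_size):
        ids = page_ids[start:start + chunk_size]
        host = torch.empty((len(ids), *shared_kv_cache.shape[1:]),
                           dtype=shared_kv_cache.dtype, pin_memory=True)
        with torch.cuda.stream(stream):
            for i, pid in enumerate(ids):
                host[i].copy_(shared_kv_cache[pid], non_blocking=True)
        event = torch.cuda.Event()
        event.record(stream)
        pending.append((host, event))
        # Drain earlier chunks whose copies already finished while this chunk
        # is still copying on the side stream.
        while pending and pending[0][1].query():
            chunk, _ = pending.pop(0)
            send_fn(chunk)
    # Flush whatever is left.
    for chunk, event in pending:
        event.synchronize()
        send_fn(chunk)
```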

