Conversation
@orozery orozery commented Aug 10, 2025

This PR adds an offloading connector that delegates to a generic API introduced in #19848.
The actual implementation of this API is instantiated through a factory, which is currently empty.
A small follow-up PR will register a CPU implementation based on #20075 (scheduler-side implementation) and #21448 (worker-side implementation).

Part of RFC #19854.
Depends on PRs #19728, #19848, #19737.
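For context, here is a minimal sketch of what such a factory-based registration could look like. All names below (OffloadingBackendFactory, register_backend, create_backend, CPUOffloadingBackend) are illustrative assumptions, not the actual API introduced by this PR:

# Hypothetical sketch of a factory that maps backend names to constructors.
from typing import Callable


class OffloadingBackend:
    """Placeholder base class for a concrete offloading implementation."""


class OffloadingBackendFactory:
    _registry: dict[str, Callable[[], OffloadingBackend]] = {}

    @classmethod
    def register_backend(cls, name: str,
                         ctor: Callable[[], OffloadingBackend]) -> None:
        # The factory starts empty; follow-up PRs register concrete backends.
        cls._registry[name] = ctor

    @classmethod
    def create_backend(cls, name: str) -> OffloadingBackend:
        if name not in cls._registry:
            raise ValueError(f"Unknown offloading backend: {name}")
        return cls._registry[name]()


# A follow-up PR could then register a CPU backend, e.g.:
# OffloadingBackendFactory.register_backend("cpu", CPUOffloadingBackend)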


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This PR introduces a new offloading connector. The implementation is extensive and adds a lot of new components. My review found several critical issues that need to be addressed. These include a race condition in the tests, a critical assertion that would crash workers on transfer failures, a resource leak due to unjoined threads, and an incorrect list slicing that would lead to errors. These issues affect both the correctness of the new feature and the reliability of its tests.

@orozery orozery force-pushed the offloading-connector branch 2 times, most recently from 4b24d03 to 4fca175 on August 10, 2025 at 14:49
@KuntaiDu
Collaborator

mark, will take a look and review after this PR gets stable.

@orozery orozery force-pushed the offloading-connector branch from 4fca175 to 8d7a0d7 on August 11, 2025 at 13:43
@mergify mergify bot added the documentation label Aug 11, 2025
@orozery orozery force-pushed the offloading-connector branch 2 times, most recently from 866a51c to 4872976 on August 11, 2025 at 19:11
This commit adds a new offloading component, composed of:
1. A scheduler-side OffloadingManager (abstract) which kicks off KV data transfers and keeps track of offloaded data.
2. A worker-side OffloadingQueueManager which asynchronously manages KV transfers.

Signed-off-by: Or Ozeri <[email protected]>
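Roughly, the two components described in this commit could be sketched as follows; the method names and signatures below are paraphrased for illustration and may differ from the actual code:

# Paraphrased sketch of the scheduler-side and worker-side components.
from abc import ABC, abstractmethod
from typing import Any


class OffloadingManager(ABC):
    """Scheduler-side: kicks off KV transfers and tracks offloaded blocks."""

    @abstractmethod
    def prepare_store(self, block_hashes: list[int]) -> Any:
        """Decide which blocks to offload and where to store them."""

    @abstractmethod
    def prepare_load(self, block_hashes: list[int]) -> Any:
        """Resolve where previously offloaded blocks can be loaded from."""


class OffloadingQueueManager:
    """Worker-side: executes KV transfer jobs asynchronously."""

    def transfer_async(self, job_id: int, transfer_spec: Any) -> None:
        """Queue a transfer job without blocking the forward pass."""

    def get_finished(self) -> list[tuple[int, bool]]:
        """Return (job_id, success) pairs for completed transfers."""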

mergify bot commented Aug 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 14, 2025
@orozery orozery force-pushed the offloading-connector branch from 4872976 to 9d2e0b9 on August 14, 2025 at 12:40
@mergify mergify bot removed the needs-rebase label Aug 14, 2025
This commit moves the request block hashes from the KVCacheManager
to the Request object itself.
In particular, this allows connectors to access the request block hashes.

Signed-off-by: Or Ozeri <[email protected]>
This commit adds a new scheduler-side connector API
to collect KV cache events.
Additionally, we add a medium field to KV events to allow
distinguishing KV events across different mediums
(e.g. blocks stored on CPU, disk, or GPU, which is the default).

Signed-off-by: Or Ozeri <[email protected]>
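To illustrate the medium field mentioned above (the event class and field names here are simplified and may not match the actual vLLM KV event classes):

# Simplified sketch of a KV cache event carrying a medium tag.
from dataclasses import dataclass


@dataclass
class BlockStoredEvent:
    block_hashes: list[int]
    medium: str = "gpu"  # default medium; connectors may emit "cpu" or "disk"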
This commit introduces a new OffloadingConnector for
offloading blocks of KV data via a generic interface.

Signed-off-by: Or Ozeri <[email protected]>
@orozery orozery force-pushed the offloading-connector branch from 9d2e0b9 to 11e1629 on August 14, 2025 at 12:50
@ApostaC
Collaborator

ApostaC commented Aug 19, 2025

@orozery Hey, thanks for the amazing work!

Is there a centralized branch for us to run some benchmarks? We are excited to test it and would like to push for landing this connector-based CPU offloading solution if it has good performance 🚀.

@ZrBac

ZrBac commented Aug 22, 2025

> @orozery Hey, thanks for the amazing work!
>
> Is there a centralized branch for us to run some benchmarks? We are excited to test it and would like to push for landing this connector-based CPU offloading solution if it has good performance 🚀.

Try this branch: https://github.com/orozery/vllm/tree/cpu-offloading-afa5b7

Collaborator

@ApostaC ApostaC left a comment

@orozery Thanks for the great effort! Some high-level comments:

  1. The current implementation is a bit over-complicated. We should simplify the transfer_fn and the LoadStoreSpec abstractions to get better performance and maintainability.
  2. There are a few potential performance improvements we can make (immediately or as follow-ups):
    (a) Launch the d2h/h2d copy kernels on a separate CUDA stream.
    (b) Use CUDA events to implement the async loading so that we don't need to launch extra Python threads in the worker process.

@dataclass
class PrepareStoreOutput:
    block_hashes_to_store: list[int]
    store_specs: list[LoadStoreSpec]
Collaborator

Having store_specs be a list of Python objects will be pretty heavy.
From other related PRs, I see that we are going to transmit this list between processes and threads, plus doing some for loops over it in the worker process. This can incur a huge amount of Python-level overhead.

A proposal is to use a torch.Tensor for now, since BlockIDLoadStoreSpec is just a wrapper around an integer.
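For example, a tensor-based variant of the proposal might look like the sketch below (field names are illustrative, assuming store targets reduce to plain block IDs):

import torch
from dataclasses import dataclass


@dataclass
class PrepareStoreOutputTensors:
    # One entry per block to store, instead of a list of per-block
    # Python spec objects.
    block_hashes_to_store: torch.Tensor  # int64, shape (num_blocks,)
    store_block_ids: torch.Tensor        # int64, shape (num_blocks,)


# Construction stays cheap and can be passed between processes without
# per-element Python overhead:
out = PrepareStoreOutputTensors(
    block_hashes_to_store=torch.tensor([101, 102, 103], dtype=torch.int64),
    store_block_ids=torch.tensor([7, 8, 9], dtype=torch.int64),
)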

Comment on lines +8 to +37
class BlockIDLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Spec for loading/storing a KV block from a given block number.
    """

    def __init__(self, block_id: int):
        self.block_id = block_id

    def __repr__(self) -> str:
        return str(self.block_id)


class GPULoadStoreSpec(BlockIDLoadStoreSpec):
    """
    Spec for loading/storing a KV block to GPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "GPU"


class CPULoadStoreSpec(BlockIDLoadStoreSpec):
    """
    Spec for loading/storing a KV block to CPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "CPU"
Collaborator

This LoadStoreSpec abstraction seems over-complicated. Why do we need it? Are there simpler alternatives (e.g., just using two lists or two tensors for cpu->gpu block IDs and gpu->cpu block IDs)?

Comment on lines +47 to +61
    @abstractmethod
    def get_transfer_functions(
        self, kv_caches: dict[str, torch.Tensor]
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        TransferFunction, int]]:
        """
        Get transfer functions along with their respective src and dst types.

        Args:
            kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.

        Yields:
            Tuples of (src_type, dst_type, transfer_function, num_threads).
        """
        pass
Collaborator

I'm not sure of the purpose of having such an abstraction for CPU offloading. The logic is a bit hard to follow here.

Can we directly call swap_blocks in the connector? That would be simpler and easier to understand.
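As a rough sketch of this simpler alternative, assuming the existing vllm._custom_ops.swap_blocks(src, dst, block_mapping) cache op, where block_mapping is an int64 tensor of (src_block_id, dst_block_id) pairs (the helper function name below is hypothetical):

import torch
from vllm import _custom_ops as ops


def offload_to_cpu(gpu_kv_caches: dict[str, torch.Tensor],
                   cpu_kv_caches: dict[str, torch.Tensor],
                   block_mapping: torch.Tensor) -> None:
    """Copy the given GPU blocks into CPU blocks, layer by layer."""
    for layer_name, gpu_cache in gpu_kv_caches.items():
        # block_mapping: (num_blocks, 2) tensor of (gpu_block_id, cpu_block_id)
        ops.swap_blocks(gpu_cache, cpu_kv_caches[layer_name], block_mapping)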

Comment on lines +51 to +56
        for thread_idx in range(num_threads):
            t = threading.Thread(target=self.run,
                                 args=(thread_idx, ),
                                 name=f"{transfer_type}-worker-{thread_idx}")
            t.start()
            self._worker_threads.append(t)
Collaborator

Having another thread in the worker process may incur extra overhead.

At a high level, we might want to use CUDA events to achieve async so that we don't need to create new threads.
IIUC, this could be a longer discussion, and we can gradually push the implementation in.
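A rough sketch of the CUDA-event approach, using plain torch.cuda primitives (the surrounding function names are illustrative):

import torch

copy_stream = torch.cuda.Stream()


def submit_copies(enqueue_copies) -> torch.cuda.Event:
    """Queue the copies on a side stream and return an event to poll later."""
    event = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        enqueue_copies()        # launches async copy kernels on copy_stream
        event.record(copy_stream)
    return event


# Later, from the existing worker loop (non-blocking):
# if event.query():
#     report_transfer_finished(job_id, success=True)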

                               job_id)

        try:
            success = self.transfer_fn(transfer_spec)
Collaborator

From what I saw in other PRs, the transfer_fn and the internal swap_blocks are launched in the same CUDA stream as the LLM inference.
This will make CPU offloading a blocking operation, resulting in a negative performance impact, especially when there is no KV cache reuse.

For performance's sake, we should make sure the d2h and h2d copies are launched on separate CUDA streams.
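For instance, a minimal sketch of launching the d2h copy on a dedicated stream so it can overlap with model execution (buffer names are illustrative; the CPU tensor must be in pinned memory for the copy to be truly asynchronous):

import torch

d2h_stream = torch.cuda.Stream()


def copy_block_to_cpu(gpu_block: torch.Tensor, cpu_block: torch.Tensor) -> None:
    # Make sure pending writes to this block on the compute stream finish first.
    d2h_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(d2h_stream):
        # Asynchronous d2h copy; does not block the compute stream.
        cpu_block.copy_(gpu_block, non_blocking=True)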

Labels: ci/build, documentation, v1