Skip to content

Commit e3779ab

Browse files
committed
v1: Introduce an offloading component
This commit adds a new offloading component, composed of: 1. A scheduler side OffloadingManager (abstract) which kicks-off KV data transfers and keeps track of offloaded data. 2. A worker side OffloadingQueueManager which asynchronously manages KV transfers. Signed-off-by: Or Ozeri <[email protected]>
1 parent afa5b7c commit e3779ab

File tree

5 files changed

+542
-0
lines changed

5 files changed

+542
-0
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ steps:
262262
- pytest -v -s v1/core
263263
- pytest -v -s v1/engine
264264
- pytest -v -s v1/entrypoints
265+
- pytest -v -s v1/offloading
265266
- pytest -v -s v1/sample
266267
- pytest -v -s v1/worker
267268
- pytest -v -s v1/structured_output

tests/v1/offloading/test_worker.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import threading
4+
5+
import pytest
6+
7+
from vllm.v1.offloading.abstract import LoadStoreSpec
8+
from vllm.v1.offloading.worker.worker import (OffloadingQueueManager,
9+
TransferSpec)
10+
11+
12+
class LoadStoreSpec1(LoadStoreSpec):
13+
14+
def __init__(self, success: bool = True, exception: bool = False):
15+
self.called_event = threading.Event()
16+
self.finished_event = threading.Event()
17+
self.success = success
18+
self.exception = exception
19+
20+
@staticmethod
21+
def medium() -> str:
22+
return "1"
23+
24+
def __repr__(self):
25+
return f"{self.medium()}: {id(self)}"
26+
27+
28+
class LoadStoreSpec2(LoadStoreSpec):
29+
30+
@staticmethod
31+
def medium() -> str:
32+
return "2"
33+
34+
def __repr__(self):
35+
return f"{self.medium()}: {id(self)}"
36+
37+
38+
def transfer_function_1_to_2(transfer_spec: TransferSpec) -> bool:
39+
srcs, dsts = transfer_spec
40+
assert len(srcs) == 1
41+
assert len(dsts) == 1
42+
43+
src, dst = srcs[0], dsts[0]
44+
assert isinstance(src, LoadStoreSpec1)
45+
assert isinstance(dst, LoadStoreSpec2)
46+
47+
src.called_event.set()
48+
src.finished_event.wait()
49+
if src.exception:
50+
raise Exception("An expected exception. Don't worry!")
51+
return src.success
52+
53+
54+
def transfer_function_2_to_1(transfer_spec: TransferSpec) -> bool:
55+
srcs, dsts = transfer_spec
56+
assert len(srcs) == 1
57+
assert len(dsts) == 1
58+
59+
src, dst = srcs[0], dsts[0]
60+
assert isinstance(src, LoadStoreSpec2)
61+
assert isinstance(dst, LoadStoreSpec1)
62+
63+
dst.called_event.set()
64+
dst.finished_event.wait()
65+
if dst.exception:
66+
raise Exception()
67+
return dst.success
68+
69+
70+
@pytest.fixture
71+
def offloading_queue_manager():
72+
manager = OffloadingQueueManager()
73+
yield manager
74+
manager.shutdown() # guaranteed cleanup after test, even on failure
75+
76+
77+
def test_offloading_queue_manager(offloading_queue_manager):
78+
"""
79+
Tests OffloadingQueueManager with 2 workers.
80+
One worker performs 1->2 transfers, and the other handles 2->1.
81+
"""
82+
offloading_queue_manager.register_worker(LoadStoreSpec1, LoadStoreSpec2,
83+
transfer_function_1_to_2)
84+
offloading_queue_manager.register_worker(LoadStoreSpec2, LoadStoreSpec1,
85+
transfer_function_2_to_1)
86+
87+
# 1st transfer 1->2 (exception)
88+
src1 = LoadStoreSpec1(exception=True)
89+
dst1 = LoadStoreSpec2()
90+
offloading_queue_manager.transfer_async(1, ([src1], [dst1]))
91+
92+
# 2ed transfer 1->2 (failure)
93+
src2 = LoadStoreSpec1(success=False)
94+
dst2 = LoadStoreSpec2()
95+
offloading_queue_manager.transfer_async(2, ([src2], [dst2]))
96+
97+
# 3rd transfer 1->2 (success)
98+
src3 = LoadStoreSpec1()
99+
dst3 = LoadStoreSpec2()
100+
offloading_queue_manager.transfer_async(3, ([src3], [dst3]))
101+
102+
# 4th transfer 2->1
103+
src4 = LoadStoreSpec2()
104+
dst4 = LoadStoreSpec1()
105+
offloading_queue_manager.transfer_async(4, ([src4], [dst4]))
106+
107+
# 1st transfer started
108+
assert src1.called_event.wait(timeout=1)
109+
110+
# 4th transfer started
111+
assert dst4.called_event.wait(timeout=1)
112+
113+
# 2ed transfer have not started (blocked by 1st)
114+
assert not src2.called_event.is_set()
115+
116+
# no transfer completed yet
117+
assert offloading_queue_manager.get_finished() == []
118+
119+
# complete 1st transfer
120+
src1.finished_event.set()
121+
122+
# 2ed transfer started
123+
src2.called_event.wait(timeout=1)
124+
125+
# 1st transfer finished with failure (exception)
126+
assert offloading_queue_manager.get_finished() == [(1, False)]
127+
128+
# complete 2ed, 3rd and 4th transfers
129+
src2.finished_event.set()
130+
src3.finished_event.set()
131+
dst4.finished_event.set()
132+
133+
# 5th transfer 1->2
134+
src5 = LoadStoreSpec1()
135+
dst5 = LoadStoreSpec2()
136+
offloading_queue_manager.transfer_async(5, ([src5], [dst5]))
137+
138+
# 6th transfer 2->1
139+
src6 = LoadStoreSpec2()
140+
dst6 = LoadStoreSpec1()
141+
offloading_queue_manager.transfer_async(6, ([src6], [dst6]))
142+
143+
# 5th and 6th transfers started
144+
assert src5.called_event.wait(timeout=1)
145+
assert dst6.called_event.wait(timeout=1)
146+
147+
# verify result of 2ed, 3rd and 4th transfers
148+
assert (sorted(offloading_queue_manager.get_finished()) == [(2, False),
149+
(3, True),
150+
(4, True)])
151+
152+
# complete 5th and 6th transfers
153+
src5.finished_event.set()
154+
dst6.finished_event.set()

vllm/v1/offloading/abstract.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
OffloadingManager class for managing KV data offloading in vLLM v1
5+
6+
This class runs in the scheduler, tracks which blocks are offloaded
7+
and their address.
8+
9+
The class provides the following primitives:
10+
lookup() - find the length of the maximal series of blocks,
11+
starting from the first one, that are all offloaded.
12+
parepare_load() - prepare given blocks to be read.
13+
This given blocks will be protected from eviction.
14+
This function returns a LoadSpec which encapsulates
15+
information required for performing the load.
16+
touch() - marks the give blocks as recently used. Can be used
17+
to track block's LRU. This function is separated from the
18+
prepare_load function to allow setting block recency even
19+
for blocks which do not need reading from the cache, such as
20+
blocks that are cached by the GPU prefix cache.
21+
complete_load() - mark blocks which were previously prepared to be
22+
loaded as done loading. This is to re-allow their eviction.
23+
prepare_store() - prepare the given blocks to be written.
24+
Returns a StoreSpec encapsulating offloading information,
25+
as well as a list of blocks that were evicted as a result.
26+
complete_store() - marks a previous store as completed.
27+
Following this call, the given blocks will become loadable.
28+
"""
29+
30+
from abc import ABC, abstractmethod
31+
from collections.abc import Iterable
32+
from dataclasses import dataclass
33+
from typing import Optional
34+
35+
36+
class LoadStoreSpec(ABC):
37+
"""
38+
Abstract metadata that encapsulates information allowing a worker
39+
to load, and optionally also to store, a block of KV data.
40+
"""
41+
42+
@staticmethod
43+
@abstractmethod
44+
def medium() -> str:
45+
"""
46+
Returns a string representation of the medium type
47+
this store/load targets.
48+
"""
49+
pass
50+
51+
52+
@dataclass
53+
class PrepareStoreOutput:
54+
block_hashes_to_store: list[int]
55+
store_specs: list[LoadStoreSpec]
56+
block_hashes_evicted: list[int]
57+
58+
59+
@dataclass
60+
class OffloadingEvent:
61+
block_hashes: list[int]
62+
block_size: int
63+
medium: str
64+
# True if blocks are removed, False if stored
65+
removed: bool
66+
67+
68+
class OffloadingManager(ABC):
69+
70+
@abstractmethod
71+
def lookup(self, block_hashes: list[int]) -> int:
72+
"""
73+
Finds the length of the maximal series of blocks, starting from the
74+
first one, that are all offloaded.
75+
76+
Args:
77+
block_hashes: the hashes identifying the blocks to lookup.
78+
79+
Returns:
80+
An integer representing the maximal number of blocks that
81+
are currently offloaded.
82+
"""
83+
pass
84+
85+
@abstractmethod
86+
def prepare_load(self, block_hashes: list[int]) -> list[LoadStoreSpec]:
87+
"""
88+
Prepare the given blocks to be read.
89+
The given blocks will be protected from eviction until
90+
complete_load is called.
91+
It assumes all given blocks are offloaded.
92+
93+
Args:
94+
block_hashes: the hashes identifying the blocks.
95+
96+
Returns:
97+
A list of LoadStoreSpec, one per each block, that can be used by
98+
a worker to locate and load the actual offloaded KV data.
99+
"""
100+
pass
101+
102+
def touch(self, block_hashes: list[int]):
103+
"""
104+
Mark the given blocks as recently used.
105+
This could in practice mean moving them to the end of an LRU list.
106+
107+
Args:
108+
block_hashes: the hashes identifying the blocks.
109+
"""
110+
return
111+
112+
def complete_load(self, block_hashes: list[int]):
113+
"""
114+
Marks previous blocks that were prepared to load as done loading.
115+
116+
Args:
117+
block_hashes: the hashes identifying the blocks.
118+
"""
119+
return
120+
121+
@abstractmethod
122+
def prepare_store(self,
123+
block_hashes: list[int]) -> Optional[PrepareStoreOutput]:
124+
"""
125+
Prepare the given blocks to be offloaded.
126+
The given blocks will be protected from eviction until
127+
complete_store is called.
128+
129+
Args:
130+
block_hashes: the hashes identifying the blocks.
131+
132+
Returns:
133+
A PrepareStoreOutput indicating which blocks need storing,
134+
where to store them (LoadStoreSpec), and list of blocks that
135+
were evicted as a result.
136+
None is returned if the blocks cannot be stored.
137+
"""
138+
pass
139+
140+
def complete_store(self, block_hashes: list[int], success: bool = True):
141+
"""
142+
Marks blocks which were previously prepared to be stored, as stored.
143+
Following this call, the blocks become loadable.
144+
If if_success is False, blocks that were not marked as stored will be
145+
removed.
146+
147+
Args:
148+
block_hashes: the hashes identifying the blocks.
149+
success: whether the blocks were stored successfully.
150+
"""
151+
return
152+
153+
def take_events(self) -> Iterable[OffloadingEvent]:
154+
"""
155+
Take the offloading events from the manager.
156+
157+
Yields:
158+
New OffloadingEvents collected since the last call.
159+
"""
160+
yield from ()

vllm/v1/offloading/mediums.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from abc import ABC
4+
5+
from vllm.v1.offloading.abstract import LoadStoreSpec
6+
7+
8+
class BlockIDLoadStoreSpec(LoadStoreSpec, ABC):
9+
"""
10+
Spec for loading/storing a KV block from a given block number.
11+
"""
12+
13+
def __init__(self, block_id: int):
14+
self.block_id = block_id
15+
16+
def __repr__(self) -> str:
17+
return str(self.block_id)
18+
19+
20+
class GPULoadStoreSpec(BlockIDLoadStoreSpec):
21+
"""
22+
Spec for loading/storing a KV block to GPU memory.
23+
"""
24+
25+
@staticmethod
26+
def medium() -> str:
27+
return "GPU"
28+
29+
30+
class CPULoadStoreSpec(BlockIDLoadStoreSpec):
31+
"""
32+
Spec for loading/storing a KV block to CPU memory.
33+
"""
34+
35+
@staticmethod
36+
def medium() -> str:
37+
return "CPU"

0 commit comments

Comments
 (0)