Skip to content

Commit e7b1232

Browse files
committed
v1: Introduce an offloading component
This commit adds a new offloading component, composed of: 1. A scheduler-side OffloadingManager (abstract) which kicks off KV data transfers and keeps track of offloaded data. 2. A worker-side OffloadingQueueManager which asynchronously manages KV transfers. Signed-off-by: Or Ozeri <[email protected]>
1 parent 693ab28 commit e7b1232

File tree

5 files changed

+494
-0
lines changed

5 files changed

+494
-0
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ steps:
262262
- pytest -v -s v1/core
263263
- pytest -v -s v1/engine
264264
- pytest -v -s v1/entrypoints
265+
- pytest -v -s v1/offloading
265266
- pytest -v -s v1/sample
266267
- pytest -v -s v1/worker
267268
- pytest -v -s v1/structured_output

tests/v1/offloading/test_worker.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import threading
4+
5+
from vllm.v1.offloading.abstract import LoadStoreSpec
6+
from vllm.v1.offloading.worker.worker import (OffloadingQueueManager,
7+
TransferSpec)
8+
9+
10+
class LoadStoreSpec1(LoadStoreSpec):
    """Test spec for medium "1" that carries per-transfer control state.

    The fake transfer functions in this module set ``called_event`` when
    they begin handling a transfer involving this spec, and then block on
    ``finished_event``. This lets a test observe when a transfer starts and
    decide when, and with which outcome, it ends.

    Args:
        success: value the transfer function returns once released.
        exception: if True, the transfer function raises instead of
            returning.
    """

    def __init__(self, success: bool = True, exception: bool = False):
        # Outcome knobs read by the transfer functions.
        self.success = success
        self.exception = exception
        # Set by the transfer function once it starts handling this spec.
        self.called_event = threading.Event()
        # Set by the test to release the blocked transfer function.
        self.finished_event = threading.Event()

    @staticmethod
    def medium() -> str:
        return "1"
21+
22+
23+
class LoadStoreSpec2(LoadStoreSpec):
    """Stateless test spec identifying medium "2"."""

    @staticmethod
    def medium() -> str:
        return "2"
28+
29+
30+
def transfer_function_1_to_2(transfer_spec: TransferSpec) -> bool:
    """Fake 1->2 transfer whose timing and outcome are driven by its source.

    Expects exactly one LoadStoreSpec1 source and one LoadStoreSpec2
    destination. Signals the source's ``called_event``, then blocks until
    the source's ``finished_event`` is set.

    Returns:
        The source's ``success`` flag.

    Raises:
        Exception: if the source's ``exception`` flag is set.
    """
    sources, destinations = transfer_spec
    assert len(sources) == 1
    assert len(destinations) == 1

    source = sources[0]
    destination = destinations[0]
    assert isinstance(source, LoadStoreSpec1)
    assert isinstance(destination, LoadStoreSpec2)

    # Announce that the transfer has started, then wait to be released.
    source.called_event.set()
    source.finished_event.wait()

    if not source.exception:
        return source.success
    raise Exception("An expected exception. Don't worry!")
44+
45+
46+
def transfer_function_2_to_1(transfer_spec: TransferSpec) -> bool:
    """Fake 2->1 transfer whose timing and outcome are driven by its
    destination.

    Expects exactly one LoadStoreSpec2 source and one LoadStoreSpec1
    destination. Signals the destination's ``called_event``, then blocks
    until the destination's ``finished_event`` is set.

    Returns:
        The destination's ``success`` flag.

    Raises:
        Exception: if the destination's ``exception`` flag is set.
    """
    sources, destinations = transfer_spec
    assert len(sources) == 1
    assert len(destinations) == 1

    source = sources[0]
    destination = destinations[0]
    assert isinstance(source, LoadStoreSpec2)
    assert isinstance(destination, LoadStoreSpec1)

    # The control state lives on the LoadStoreSpec1 side, i.e. the
    # destination for this direction.
    destination.called_event.set()
    destination.finished_event.wait()

    if not destination.exception:
        return destination.success
    raise Exception()
60+
61+
62+
def test_offloading_queue_manager():
    """
    Tests OffloadingQueueManager with 2 workers.
    One worker performs 1->2 transfers, and the other handles 2->1.

    The test checks that:
    - transfers of the same direction start one at a time, in submission
      order (the 2nd transfer does not start while the 1st is blocked),
    - the two directions make progress independently of each other,
    - get_finished() reports each (transfer id, success) pair, with a
      raised exception reported as a failure.
    """
    offloading_queue_manager = OffloadingQueueManager()
    offloading_queue_manager.register_worker(LoadStoreSpec1, LoadStoreSpec2,
                                             transfer_function_1_to_2)
    offloading_queue_manager.register_worker(LoadStoreSpec2, LoadStoreSpec1,
                                             transfer_function_2_to_1)

    # 1st transfer 1->2 (exception)
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    offloading_queue_manager.transfer_async(1, ([src1], [dst1]))

    # 2nd transfer 1->2 (failure)
    src2 = LoadStoreSpec1(success=False)
    dst2 = LoadStoreSpec2()
    offloading_queue_manager.transfer_async(2, ([src2], [dst2]))

    # 3rd transfer 1->2 (success)
    src3 = LoadStoreSpec1()
    dst3 = LoadStoreSpec2()
    offloading_queue_manager.transfer_async(3, ([src3], [dst3]))

    # 4th transfer 2->1
    src4 = LoadStoreSpec2()
    dst4 = LoadStoreSpec1()
    offloading_queue_manager.transfer_async(4, ([src4], [dst4]))

    # 1st transfer started
    assert src1.called_event.wait(timeout=1)

    # 4th transfer started
    assert dst4.called_event.wait(timeout=1)

    # 2nd transfer has not started (blocked by 1st)
    assert not src2.called_event.is_set()

    # no transfer completed yet
    assert offloading_queue_manager.get_finished() == []

    # complete 1st transfer
    src1.finished_event.set()

    # 2nd transfer started.
    # Event.wait() returns False on timeout, so the result must be asserted;
    # otherwise a transfer that never starts would go unnoticed here.
    assert src2.called_event.wait(timeout=1)

    # 1st transfer finished with failure (exception)
    # NOTE(review): this presumably relies on the worker publishing the 1st
    # result before it picks up the 2nd transfer - confirm against the
    # worker implementation.
    assert offloading_queue_manager.get_finished() == [(1, False)]

    # complete 2nd, 3rd and 4th transfers
    src2.finished_event.set()
    src3.finished_event.set()
    dst4.finished_event.set()

    # 5th transfer 1->2
    src5 = LoadStoreSpec1()
    dst5 = LoadStoreSpec2()
    offloading_queue_manager.transfer_async(5, ([src5], [dst5]))

    # 6th transfer 2->1
    src6 = LoadStoreSpec2()
    dst6 = LoadStoreSpec1()
    offloading_queue_manager.transfer_async(6, ([src6], [dst6]))

    # 5th and 6th transfers started
    assert src5.called_event.wait(timeout=1)
    assert dst6.called_event.wait(timeout=1)

    # verify result of 2nd, 3rd and 4th transfers
    # (sorted, since the two directions complete in no particular order)
    assert (sorted(offloading_queue_manager.get_finished()) == [(2, False),
                                                                (3, True),
                                                                (4, True)])

    # complete 5th and 6th transfers so the blocked transfer functions
    # can return
    src5.finished_event.set()
    dst6.finished_event.set()

vllm/v1/offloading/abstract.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
OffloadingManager class for managing KV data offloading in vLLM v1
5+
6+
This class runs in the scheduler, tracks which blocks are offloaded
7+
and their address.
8+
9+
The class provides the following primitives:
10+
lookup() - find the length of the maximal series of blocks,
11+
starting from the first one, that are all offloaded.
12+
prepare_load() - prepare given blocks to be read.
13+
The given blocks will be protected from eviction.
14+
This function returns a LoadSpec which encapsulates
15+
information required for performing the load.
16+
touch() - marks the given blocks as recently used. Can be used
17+
to track block's LRU. This function is separated from the
18+
prepare_load function to allow setting block recency even
19+
for blocks which do not need reading from the cache, such as
20+
blocks that are cached by the GPU prefix cache.
21+
complete_load() - mark blocks which were previously prepared to be
22+
loaded as done loading. This is to re-allow their eviction.
23+
prepare_store() - prepare the given blocks to be written.
24+
Returns a StoreSpec encapsulating offloading information,
25+
as well as a list of blocks that were evicted as a result.
26+
complete_store() - marks a previous store as completed.
27+
Following this call, the given blocks will become loadable.
28+
"""
29+
30+
from abc import ABC, abstractmethod
31+
from dataclasses import dataclass
32+
from typing import Optional
33+
34+
35+
class LoadStoreSpec(ABC):
    """
    Abstract description of where a block of KV data lives.

    An instance holds whatever metadata a worker needs in order to load
    the block, and optionally also to store it.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Name the medium type this load/store spec targets.

        Returns:
            A string representation of the medium type.
        """
49+
50+
51+
@dataclass
class PrepareStoreOutput:
    """Result of OffloadingManager.prepare_store()."""

    # Hashes of the blocks that need storing.
    block_hashes_to_store: list[int]
    # Where to store the blocks.
    store_specs: list[LoadStoreSpec]
    # Hashes of the blocks that were evicted as a result of the preparation.
    block_hashes_evicted: list[int]
56+
57+
58+
class OffloadingManager(ABC):
    """
    Abstract interface for tracking offloaded KV blocks.

    Blocks are identified by their hashes. Loads and stores are two-phase:
    a prepare_* call protects the blocks from eviction, and the matching
    complete_* call releases that protection.
    """

    @abstractmethod
    def lookup(self, block_hashes: list[int]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to lookup.

        Returns:
            The length of the longest prefix of block_hashes whose blocks
            are all currently offloaded.
        """

    @abstractmethod
    def prepare_load(self, block_hashes: list[int]) -> list[LoadStoreSpec]:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A list of LoadStoreSpec, one per each block, that can be used by
            a worker to locate and load the actual offloaded KV data.
        """

    @abstractmethod
    def touch(self, block_hashes: list[int]) -> None:
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.

        Args:
            block_hashes: the hashes identifying the blocks.
        """

    @abstractmethod
    def complete_load(self, block_hashes: list[int]) -> None:
        """
        Marks previous blocks that were prepared to load as done loading.

        Args:
            block_hashes: the hashes identifying the blocks.
        """

    @abstractmethod
    def prepare_store(self,
                      block_hashes: list[int]) -> Optional[PrepareStoreOutput]:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """

    @abstractmethod
    def complete_store(self,
                       block_hashes: list[int],
                       success: bool = True) -> None:
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not marked as stored will be
        removed.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """

vllm/v1/offloading/mediums.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from abc import ABC
4+
5+
from vllm.v1.offloading.abstract import LoadStoreSpec
6+
7+
8+
class BlockIDLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Load/store spec that locates a KV block by its block number.

    The class stays abstract: it does not implement medium().
    """

    def __init__(self, block_id: int):
        # The block number this spec refers to.
        self.block_id = block_id

    def __repr__(self) -> str:
        # Rendered as the bare block number.
        return str(self.block_id)
18+
19+
20+
class GPULoadStoreSpec(BlockIDLoadStoreSpec):
    """
    Block-number spec for loading/storing a KV block whose medium is
    GPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "GPU"
28+
29+
30+
class CPULoadStoreSpec(BlockIDLoadStoreSpec):
    """
    Block-number spec for loading/storing a KV block whose medium is
    CPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "CPU"

0 commit comments

Comments
 (0)