Commit 63d9ede

[Feature] kv_transfer/kv_connector: Add aibrix_pd_reuse_connector to support PD + reuse (#1852)
* Add SHFS (Shared File System) for PD + reuse test

  Signed-off-by: Dengcheng Zhu <[email protected]>

* kv_transfer/kv_connector: Add aibrix_pd_reuse_connector to support PD + reuse

  Signed-off-by: Dengcheng Zhu <[email protected]>

* Add AIBrixPDReuseConnector test yaml files to verify & compare TPOT

  Signed-off-by: Dengcheng Zhu <[email protected]>

---------

Signed-off-by: Dengcheng Zhu <[email protected]>
1 parent b781a8b commit 63d9ede

File tree

7 files changed, +2167 -0 lines changed

python/aibrix_kvcache/aibrix_kvcache/envs.py

Lines changed: 12 additions & 0 deletions
@@ -146,6 +146,11 @@
 # EIC Config
 AIBRIX_KV_CACHE_OL_EIC_CONFIG_FILE: str = ""
 
+# SHFS (Shared File System) Env Vars
+AIBRIX_KV_CACHE_OL_SHFS_ROOT: str = os.path.expanduser(
+    os.path.join(os.path.expanduser("~"), ".kv_cache_ol", "shfs")
+)
+
 # The begin-* and end* here are used by the documentation generator
 # to extract the used env vars.

@@ -427,6 +432,13 @@
     "AIBRIX_KV_CACHE_OL_EIC_CONFIG_FILE": lambda: (
         os.getenv("AIBRIX_KV_CACHE_OL_EIC_CONFIG_FILE", "").strip()
     ),
+    # ================== SHFS Env Vars ==================
+    "AIBRIX_KV_CACHE_OL_SHFS_ROOT": lambda: os.path.expanduser(
+        os.getenv(
+            "AIBRIX_KV_CACHE_OL_SHFS_ROOT",
+            os.path.join(os.path.expanduser("~"), ".kv_cache_ol", "shfs"),
+        )
+    ),
 }
 
 # end-env-vars-definition
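
Note that the default root, ~/.kv_cache_ol/shfs, resolves per node, so it only behaves as a shared store if that path itself sits on a shared mount. A minimal sketch of overriding it before engine startup; the mount path below is illustrative, not part of this commit:

import os

# Hypothetical deployment step: point every prefiller/decoder engine at the
# same shared mount so they all resolve identical KVCache file paths.
os.environ["AIBRIX_KV_CACHE_OL_SHFS_ROOT"] = "/mnt/shared-nfs/kv_cache_ol/shfs"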

python/aibrix_kvcache/aibrix_kvcache/l2/connectors/connector.py

Lines changed: 4 additions & 0 deletions
@@ -112,6 +112,10 @@ def create(
             from .eic import EICConnector
 
             return EICConnector.from_envs(conn_id, executor, **kwargs)
+        elif backend_name == "SHFS":
+            from .shfs import SHFSConnector
+
+            return SHFSConnector.from_envs(conn_id, executor, **kwargs)
         else:
             raise ValueError(f"Unknown connector type: {backend_name}")
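
Any caller that resolves backend_name to "SHFS" now gets the new connector. A hedged sketch of exercising it directly via from_envs, the same call the factory makes; the import path is inferred from the file layout below and the conn_id is illustrative, neither is taken from this commit:

from concurrent.futures import ThreadPoolExecutor

# Assumed module path, derived from
# python/aibrix_kvcache/aibrix_kvcache/l2/connectors/shfs.py
from aibrix_kvcache.l2.connectors.shfs import SHFSConnector

executor = ThreadPoolExecutor(max_workers=4)
# Files land under AIBRIX_KV_CACHE_OL_SHFS_ROOT/<conn_id>/...
conn = SHFSConnector.from_envs("pd-reuse-demo", executor)
assert conn.open().is_ok()  # creates the root directory if missing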

python/aibrix_kvcache/aibrix_kvcache/l2/connectors/shfs.py (new file)

Lines changed: 326 additions & 0 deletions

# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from concurrent.futures import Executor
from pathlib import Path
from typing import Sequence

import torch

from ... import envs
from ...common import AsyncBase
from ...common.absl_logging import getLogger
from ...memory import MemoryRegion
from ...status import Status, StatusCodes
from ...utils import ensure_dir_exist
from . import Connector, ConnectorFeature

logger = getLogger(__name__)


@AsyncBase.async_wrap(
    exists="_exists", get="_get", put="_put", delete="_delete"
)
class SHFSConnector(Connector[bytes, torch.Tensor], AsyncBase):
    """Shared File System (SHFS) connector for KVCache L2 storage.

    This connector stores KVCache blocks as files in a shared file system.
    All prefiller and decoder vLLM engines can access the same shared
    directory to store and retrieve KVCache blocks.
    """

    def __init__(
        self,
        root_path: str,
        executor: Executor,
    ):
        super().__init__(executor)
        self.root_path = Path(root_path)
        self.conn_id: str | None = None  # Will be set in from_envs

    @classmethod
    def from_envs(
        cls, conn_id: str, executor: Executor, **kwargs
    ) -> "SHFSConnector":
        """Create a connector from environment variables."""
        root = envs.AIBRIX_KV_CACHE_OL_SHFS_ROOT

        # Create full path: root/conn_id
        full_path = os.path.join(os.path.expanduser(root), conn_id)

        instance = cls(full_path, executor)
        instance.conn_id = conn_id
        return instance

    @property
    def name(self) -> str:
        return "SHFS"

    @property
    def feature(self) -> ConnectorFeature:
        """SHFS connector features."""
        feature = ConnectorFeature(mput_mget=True)
        return feature

    def __del__(self) -> None:
        self.close()

    def _key_to_filepath(self, key: bytes) -> Path:
        """Convert a key (bytes) to a file path.

        Args:
            key: The cache key as bytes.

        Returns:
            Path object for the file.
        """
        key_hex = key.hex()
        # Create a two-level directory structure to avoid too many files in
        # one directory. Use first 2 chars and next 2 chars as subdirectories.
        if len(key_hex) >= 4:
            subdir1 = key_hex[:2]
            subdir2 = key_hex[2:4]
            filepath = self.root_path / subdir1 / subdir2 / key_hex
        else:
            # Fallback for very short keys
            filepath = self.root_path / key_hex

        return filepath

    @Status.capture_exception
    def open(self) -> Status:
        """Open a connection by ensuring the root directory exists."""
        try:
            ensure_dir_exist(str(self.root_path))
            return Status.ok()
        except Exception as e:
            logger.error(f"SHFS open() failed: {e}")
            return Status(
                StatusCodes.ERROR, f"Failed to create root directory: {e}"
            )

    @Status.capture_exception
    def close(self) -> Status:
        """Close a connection."""
        return Status.ok()

    def get_batches(
        self,
        keys: Sequence[bytes],
        mrs: Sequence[MemoryRegion | Sequence[MemoryRegion]],
        batch_size: int,
    ) -> Sequence[
        Sequence[tuple[bytes, MemoryRegion | Sequence[MemoryRegion]]]
    ]:
        """Get batches for mput/mget operations."""
        batches = []
        current_batch = []

        for key, mr in zip(keys, mrs):
            current_batch.append((key, mr))
            if len(current_batch) >= batch_size:
                batches.append(current_batch)
                current_batch = []

        if current_batch:
            batches.append(current_batch)

        return batches

    @Status.capture_exception
    async def mget(
        self,
        keys: Sequence[bytes],
        mrs: Sequence[MemoryRegion | Sequence[MemoryRegion]],
    ) -> Sequence[Status]:
        """MGet a list of values."""
        statuses = []

        for i, (key, mr) in enumerate(zip(keys, mrs)):
            status = await self.get(key, mr)
            statuses.append(status)
            if not status.is_ok() and not status.is_not_found():
                logger.error(f"SHFS mget[{i}] failed: {status}")

        return statuses

    @Status.capture_exception
    async def mput(
        self,
        keys: Sequence[bytes],
        mrs: Sequence[MemoryRegion | Sequence[MemoryRegion]],
    ) -> Sequence[Status]:
        """MPut a list of key value pairs."""
        statuses = []

        for i, (key, mr) in enumerate(zip(keys, mrs)):
            status = await self.put(key, mr)
            statuses.append(status)
            if not status.is_ok():
                logger.error(f"SHFS mput[{i}] failed: {status}")

        return statuses

    @Status.capture_exception
    def _exists(self, key: bytes) -> Status:
        """Check if key is in the store."""
        filepath = self._key_to_filepath(key)

        try:
            if filepath.exists():
                return Status.ok()
            else:
                return Status(StatusCodes.NOT_FOUND)
        except Exception as e:
            logger.error(f"SHFS exists failed: {e}")
            return Status(StatusCodes.ERROR, f"Failed to check existence: {e}")

    @Status.capture_exception
    def _get(
        self,
        key: bytes,
        mr: MemoryRegion | Sequence[MemoryRegion],
    ) -> Status:
        """Get a value."""
        if isinstance(mr, Sequence):
            # For sequence of MRs, we need to handle differently.
            # For now, assume single MR.
            if len(mr) != 1:
                logger.error(
                    f"SHFS get: Sequence MR with {len(mr)} elements unsupported"
                )
                return Status(
                    StatusCodes.ERROR,
                    "Sequence MR with multiple elements unsupported",
                )
            mr = mr[0]

        filepath = self._key_to_filepath(key)

        try:
            if not filepath.exists():
                return Status(StatusCodes.NOT_FOUND)

            file_size = filepath.stat().st_size

            if file_size != mr.length:
                logger.error(
                    f"SHFS get: file size mismatch: {file_size} != {mr.length}"
                )
                return Status(
                    StatusCodes.ERROR,
                    f"File size mismatch: {file_size} != {mr.length}",
                )

            with open(filepath, "rb") as f:
                data = f.read()

            if len(data) != mr.length:
                logger.error(
                    f"SHFS get: data size mismatch: {len(data)} != {mr.length}"
                )
                return Status(
                    StatusCodes.ERROR,
                    f"Data size mismatch: {len(data)} != {mr.length}",
                )

            mr.fill(data)
            return Status.ok()

        except Exception as e:
            logger.error(f"SHFS get failed: {e}")
            return Status(StatusCodes.ERROR, f"Failed to get value: {e}")

    @Status.capture_exception
    def _put(
        self,
        key: bytes,
        mr: MemoryRegion | Sequence[MemoryRegion],
    ) -> Status:
        """Put a key value pair."""
        if isinstance(mr, Sequence):
            # For sequence of MRs, we need to handle differently.
            # For now, assume single MR.
            if len(mr) != 1:
                logger.error(
                    f"SHFS put: Sequence MR with {len(mr)} elements unsupported"
                )
                return Status(
                    StatusCodes.ERROR,
                    "Sequence MR with multiple elements unsupported",
                )
            mr = mr[0]

        filepath = self._key_to_filepath(key)

        try:
            filepath.parent.mkdir(parents=True, exist_ok=True)

            data = mr.tobytes()

            # Write atomically: write to temp file first, then rename
            temp_filepath = filepath.with_suffix(filepath.suffix + ".tmp")

            with open(temp_filepath, "wb") as f:
                bytes_written = f.write(data)

            if bytes_written != len(data):
                msg = f"Incomplete write: {bytes_written} != {len(data)}"
                logger.error(f"SHFS put: {msg.lower()}")
                temp_filepath.unlink(missing_ok=True)
                return Status(StatusCodes.ERROR, msg)

            temp_filepath.replace(filepath)

            if filepath.exists():
                actual_size = filepath.stat().st_size
                if actual_size != len(data):
                    msg = (
                        f"File size mismatch after write: "
                        f"{actual_size} != {len(data)}"
                    )
                    logger.error(f"SHFS put: {msg.lower()}")
                    return Status(StatusCodes.ERROR, msg)

            return Status.ok()

        except Exception as e:
            logger.error(f"SHFS put failed: {e}")
            temp_filepath = filepath.with_suffix(filepath.suffix + ".tmp")
            temp_filepath.unlink(missing_ok=True)
            return Status(StatusCodes.ERROR, f"Failed to put value: {e}")

    @Status.capture_exception
    def _delete(self, key: bytes) -> Status:
        """Delete a key."""
        filepath = self._key_to_filepath(key)

        try:
            if filepath.exists():
                filepath.unlink()
                return Status.ok()
            else:
                return Status(StatusCodes.NOT_FOUND)
        except Exception as e:
            logger.error(f"SHFS delete failed: {e}")
            return Status(StatusCodes.ERROR, f"Failed to delete key: {e}")

    def register_slabs(self, slabs: list[torch.Tensor]) -> Status:
        """Register slabs with backend-specific register function.

        SHFS doesn't need to register slabs since it uses file I/O.
        """
        # No-op for SHFS
        return Status.ok()
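
The two-level fan-out in _key_to_filepath keeps any one directory from accumulating an unbounded number of cache files. A standalone sketch of the same mapping, with an illustrative root and key:

from pathlib import Path

def key_to_filepath(root: Path, key: bytes) -> Path:
    # Mirrors SHFSConnector._key_to_filepath: hex-encode the key and use the
    # first four hex characters as two levels of shard directories.
    key_hex = key.hex()
    if len(key_hex) >= 4:
        return root / key_hex[:2] / key_hex[2:4] / key_hex
    return root / key_hex  # fallback for very short keys

print(key_to_filepath(Path("/mnt/shfs"), bytes.fromhex("deadbeef00112233")))
# -> /mnt/shfs/de/ad/deadbeef00112233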

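_put's write-to-temp-then-rename sequence is what makes concurrent access safe on a shared file system: a reader observes either no file or a complete block, never a partial write. A minimal standalone sketch of that pattern, with illustrative paths:

from pathlib import Path

def atomic_write(filepath: Path, data: bytes) -> None:
    # Write to a sibling temp file, then rename over the target.
    # Path.replace() is atomic on POSIX filesystems, so readers never
    # see a half-written file under the final name.
    filepath.parent.mkdir(parents=True, exist_ok=True)
    tmp = filepath.with_suffix(filepath.suffix + ".tmp")
    try:
        tmp.write_bytes(data)
        tmp.replace(filepath)
    except Exception:
        tmp.unlink(missing_ok=True)
        raise

atomic_write(Path("/tmp/shfs-demo/de/ad/deadbeef00112233"), b"\x00" * 4096)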