Skip to content

Commit ea7ef63

Browse files
[PD-Disagg] Deduplicate common KVManager methods into CommonKVManager (sgl-project#19205)
Co-authored-by: Shangming Cai <csmthu@gmail.com>
1 parent 8aeb16f commit ea7ef63

File tree

4 files changed: 20 additions and 66 deletions.

python/sglang/srt/disaggregation/common/conn.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,8 @@ def __init__(
9999
logger.debug(f"kv manager bind to {zmq_bind_host}:{self.rank_port}")
100100

101101
self.request_status: Dict[int, KVPoll] = {}
102+
self.failure_records: Dict[int, str] = {}
103+
self.failure_lock = threading.Lock()
102104

103105
if self.disaggregation_mode == DisaggregationMode.PREFILL:
104106
self.register_to_bootstrap()
@@ -115,6 +117,24 @@ def __init__(
115117
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
116118
)
117119

120+
def check_status(self, bootstrap_room: int) -> KVPoll:
121+
return self.request_status[bootstrap_room]
122+
123+
def update_status(self, bootstrap_room: int, status: KVPoll):
124+
if bootstrap_room not in self.request_status:
125+
self.request_status[bootstrap_room] = status
126+
else:
127+
if status == KVPoll.Failed:
128+
self.request_status[bootstrap_room] = KVPoll.Failed
129+
else:
130+
self.request_status[bootstrap_room] = max(
131+
self.request_status[bootstrap_room], status
132+
)
133+
134+
def record_failure(self, bootstrap_room: int, failure_reason: str):
135+
with self.failure_lock:
136+
self.failure_records[bootstrap_room] = failure_reason
137+
118138
def ensure_parallel_info(self, bootstrap_addr: str) -> bool:
119139
"""Fetch and cache prefill parallel info if not yet available.
120140
Returns True if info is available (cached or freshly fetched).

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 0 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515
import numpy.typing as npt
1616
import requests
17-
import zmq
1817

1918
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
2019
from sglang.srt.disaggregation.common.conn import (
@@ -235,9 +234,6 @@ def __init__(
235234
# These timeout requests should be aborted to release the tree cache.
236235
self.waiting_timeout = envs.SGLANG_DISAGGREGATION_WAITING_TIMEOUT.get()
237236

238-
self.failure_records: Dict[int, str] = {}
239-
self.failure_lock = threading.Lock()
240-
241237
def init_engine(self):
242238
self.engine = get_mooncake_transfer_engine()
243239

@@ -1095,25 +1091,6 @@ def add_transfer_request(
10951091
)
10961092
)
10971093

1098-
def check_status(self, bootstrap_room: int):
1099-
return self.request_status[bootstrap_room]
1100-
1101-
def update_status(self, bootstrap_room: int, status: KVPoll):
1102-
if bootstrap_room not in self.request_status:
1103-
self.request_status[bootstrap_room] = status
1104-
else:
1105-
# NOTE: status is only allowed to be incremented unless it is KVPoll.Failed
1106-
if status == KVPoll.Failed:
1107-
self.request_status[bootstrap_room] = KVPoll.Failed
1108-
else:
1109-
self.request_status[bootstrap_room] = max(
1110-
self.request_status[bootstrap_room], status
1111-
)
1112-
1113-
def record_failure(self, bootstrap_room: int, failure_reason: str):
1114-
with self.failure_lock:
1115-
self.failure_records[bootstrap_room] = failure_reason
1116-
11171094
def get_session_id(self):
11181095
return self.engine.get_session_id()
11191096

@@ -1242,11 +1219,6 @@ def abort(self):
12421219

12431220

12441221
class MooncakeKVReceiver(CommonKVReceiver):
1245-
_ctx = zmq.Context()
1246-
_socket_cache = {}
1247-
_socket_locks = {}
1248-
_global_lock = threading.Lock()
1249-
12501222
def __init__(
12511223
self,
12521224
mgr: MooncakeKVManager,

python/sglang/srt/disaggregation/mori/conn.py

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -191,8 +191,6 @@ def __init__(
191191
self.kv_mem_descs: List[MemoryDesc] = []
192192
self.aux_mem_descs: List[MemoryDesc] = []
193193
self.state_mem_descs: List[MemoryDesc] = []
194-
self.failure_records: Dict[int, str] = {}
195-
self.failure_lock = threading.Lock()
196194
self.transfer_lock = threading.Lock()
197195
self._register_local_buffers()
198196
if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -293,24 +291,6 @@ def _register_local_buffers(self) -> None:
293291
)
294292
self.state_mem_descs.append(desc)
295293

296-
def check_status(self, bootstrap_room: int):
297-
return self.request_status[bootstrap_room]
298-
299-
def update_status(self, bootstrap_room: int, status: KVPoll):
300-
if bootstrap_room not in self.request_status:
301-
self.request_status[bootstrap_room] = status
302-
else:
303-
if status == KVPoll.Failed:
304-
self.request_status[bootstrap_room] = KVPoll.Failed
305-
else:
306-
self.request_status[bootstrap_room] = max(
307-
self.request_status[bootstrap_room], status
308-
)
309-
310-
def record_failure(self, bootstrap_room: int, failure_reason: str) -> None:
311-
with self.failure_lock:
312-
self.failure_records[bootstrap_room] = failure_reason
313-
314294
def _handle_register_message(self, payload: List[bytes]) -> None:
315295
try:
316296
register_info = KVArgsRegisterInfo.from_zmq(payload)

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -300,24 +300,6 @@ def _handle_node_failure(self, failed_bootstrap_addr):
300300
logger.error(f"Let room {room} be failed due to prefill down")
301301
self.update_status(room, KVPoll.Failed)
302302

303-
def check_status(self, bootstrap_room: int):
304-
return self.request_status[bootstrap_room]
305-
306-
def update_status(self, bootstrap_room: int, status: KVPoll):
307-
if bootstrap_room not in self.request_status:
308-
self.request_status[bootstrap_room] = status
309-
else:
310-
# NOTE: status is only allowed to be incremented unless it is KVPoll.Failed
311-
if status == KVPoll.Failed:
312-
self.request_status[bootstrap_room] = KVPoll.Failed
313-
else:
314-
self.request_status[bootstrap_room] = max(
315-
self.request_status[bootstrap_room], status
316-
)
317-
318-
def record_failure(self, bootstrap_room: int, failure_reason: str):
319-
pass
320-
321303
def register_buffer_to_engine(self):
322304
kv_addrs = []
323305
for kv_data_ptr, kv_data_len in zip(

0 commit comments

Comments
 (0)