Skip to content

Commit c10101a

Browse files
Csrayz and ivyilike
authored
[Bugfix] Fix several issues with p2p xPyD in GET type (#23993)
Signed-off-by: Csrayz <[email protected]>
Signed-off-by: ivyilike <[email protected]>
Co-authored-by: ivyilike <[email protected]>
1 parent ac24388 commit c10101a

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def inject_kv_into_layer(
178178

179179
# Load the KV for each request each layer
180180
for request in metadata.requests:
181+
request_id = request.request_id
182+
ip, port = self.parse_request_id(request_id, False)
183+
remote_address = ip + ":" + str(port + self._rank)
181184
for layer_name in forward_context.no_compile_layers:
182185
layer = forward_context.no_compile_layers[layer_name]
183186

@@ -191,7 +194,7 @@ def inject_kv_into_layer(
191194
layer = kv_cache[forward_context.virtual_engine]
192195

193196
kv_cache = self.p2p_nccl_engine.recv_tensor(
194-
request.request_id + "#" + layer_name)
197+
request.request_id + "#" + layer_name, remote_address)
195198

196199
if kv_cache is None:
197200
logger.warning("🚧kv_cache is None, %s", request.request_id)

vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def __init__(self,
134134
# PUT or PUT_ASYNC
135135
# tensor_id: torch.Tensor
136136
self.send_queue: deque[SendQueueItem] = deque()
137-
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
138137
if self.send_type == "PUT_ASYNC":
139138
self._send_thread = threading.Thread(target=self.send_async,
140139
daemon=True)
@@ -143,6 +142,7 @@ def __init__(self,
143142
# tensor_id: torch.Tensor/(addr, dtype, shape)
144143
self.recv_store: dict[str, Any] = {}
145144
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
145+
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
146146
self.socks: dict[str, Any] = {} # remote_address: client socket
147147
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
148148

@@ -223,18 +223,26 @@ def send_tensor(
223223
# GET
224224
with self.send_store_cv:
225225
tensor_size = tensor.element_size() * tensor.numel()
226+
if tensor_size > self.buffer_size_threshold:
227+
logger.warning(
228+
"❗[GET]tensor_id:%s, tensor_size:%d, is greater than"
229+
"buffer size threshold :%d, skip send to %s, rank:%d",
230+
tensor_id, tensor_size, self.buffer_size_threshold,
231+
remote_address, self.rank)
232+
return False
226233
while (self.buffer_size + tensor_size
227234
> self.buffer_size_threshold):
228-
oldest_tenser_id = next(iter(self.send_store))
229-
oldest_tenser = self.send_store.pop(oldest_tenser_id)
230-
oldest_tenser_size = oldest_tenser.element_size(
231-
) * oldest_tenser.numel()
232-
self.buffer_size -= oldest_tenser_size
233-
logger.info(
235+
assert len(self.send_store) > 0
236+
oldest_tensor_id = next(iter(self.send_store))
237+
oldest_tensor = self.send_store.pop(oldest_tensor_id)
238+
oldest_tensor_size = oldest_tensor.element_size(
239+
) * oldest_tensor.numel()
240+
self.buffer_size -= oldest_tensor_size
241+
logger.debug(
234242
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
235-
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
243+
" buffer_size:%d, oldest_tensor_size:%d, rank:%d",
236244
remote_address, tensor_id, tensor_size, self.buffer_size,
237-
oldest_tenser_size, self.rank)
245+
oldest_tensor_size, self.rank)
238246

239247
self.send_store[tensor_id] = tensor
240248
self.buffer_size += tensor_size

0 commit comments

Comments
 (0)