Commit e07bb41

Abatom authored and simon-mo committed

[V1][P/D]P2pNcclConnector supports flashinfer (vllm-project#23536)

Signed-off-by: Abatom <[email protected]>
Co-authored-by: Simon Mo <[email protected]>

1 parent 6a56bff · commit e07bb41

File tree: 1 file changed

vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py

Lines changed: 78 additions & 80 deletions
@@ -30,27 +30,19 @@
 class ReqMeta:
     # Request Id
     request_id: str
-    # Request tokens
-    token_ids: torch.Tensor
-    # Slot mappings, should have the same length as token_ids
-    slot_mapping: torch.Tensor
+    # Request block ids
+    block_ids: torch.Tensor
+    # Request num tokens
+    num_tokens: int

     @staticmethod
     def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                   block_size: int) -> "ReqMeta":
-        valid_num_tokens = len(token_ids)
-        token_ids_tensor = torch.tensor(token_ids)
         block_ids_tensor = torch.tensor(block_ids)
-        num_blocks = block_ids_tensor.shape[0]
-        block_offsets = torch.arange(0, block_size)
-        slot_mapping = block_offsets.reshape((1, block_size)) + \
-            block_ids_tensor.reshape((num_blocks, 1)) * block_size
-        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
-
         return ReqMeta(
             request_id=request_id,
-            token_ids=token_ids_tensor,
-            slot_mapping=slot_mapping,
+            block_ids=block_ids_tensor,
+            num_tokens=len(token_ids),
         )


@@ -123,63 +115,58 @@ def start_load_kv(self, forward_context: "ForwardContext",
             return

         def inject_kv_into_layer(
-            dst_kv_cache_layer: torch.Tensor,
-            src_kv_cache: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            layer: torch.Tensor,
+            kv_cache: torch.Tensor,
+            block_ids: torch.Tensor,
             request_id: str,
         ) -> None:
-            """Inject the KV cache into the layer.
+            """
+            Inject KV cache data into a given attention layer tensor.
+
+            This function updates `layer` in-place with values from `kv_cache`,
+            handling different backend layouts:
+            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
+              indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            If the number of provided block IDs does not match the number of KV
+            blocks, only the overlapping portion is updated, and a warning is
+            logged.

             Args:
-                dst_kv_cache_layer (torch.Tensor): the destination KV cache
-                    layer. In shape [2, num_pages, page_size, xxx] if not
-                    using MLA, [num_pages, page_size, xxx] otherwise.
-                src_kv_cache (torch.Tensor): the source KV cache. In shape
-                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
-                    otherwise.
-                slot_mapping (torch.Tensor): the slot mapping. In shape
-                    [num_tokens].
-                request_id (str): request id for log
+                layer (torch.Tensor): The attention layer KV tensor to update.
+                kv_cache (torch.Tensor): The KV cache tensor to inject.
+                block_ids (torch.Tensor): Indices of the blocks to update.
+                request_id (str): Request identifier used for logging.
+
+            Returns:
+                None. The function modifies `layer` in-place.
             """
-            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages = dst_kv_cache_layer_shape[0]
-                page_size = dst_kv_cache_layer_shape[1]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              0)
-                num_token = src_kv_cache.shape[0]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                num_block = kv_cache.shape[0]
+                self.check_tensors_except_dim(layer, kv_cache, 0)
+                if len(block_ids) == num_block:
+                    layer[block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
-            else:
-                num_pages = dst_kv_cache_layer_shape[1]
-                page_size = dst_kv_cache_layer_shape[2]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    2, num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              1)
-                num_token = src_kv_cache.shape[1]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
+
+            elif layer.shape[0] == 2:  # FlashAttention
+                num_block = kv_cache.shape[1]
+                self.check_tensors_except_dim(layer, kv_cache, 1)
+                if len(block_ids) == num_block:
+                    layer[:, block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[:, slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[:, block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)

             # Get the metadata
             metadata: KVConnectorMetadata = \
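The branch on `layer.shape[1] == 2` versus `layer.shape[0] == 2` in the new `inject_kv_into_layer` reflects how the backends lay out their paged KV caches. The standalone sketch below uses assumed toy shapes (the real cache tensors carry backend-specific dimensions and dtypes) to show why MLA/FlashInfer-style caches are block-indexed along the first dimension and FlashAttention-style caches along the second:

```python
# Illustration only, with assumed toy shapes: block-wise KV injection for the
# two layouts the new inject_kv_into_layer branches on.
import torch

num_blocks, page_size, num_heads, head_dim = 8, 16, 4, 32
block_ids = torch.tensor([3, 5])   # blocks belonging to one request

# FlashInfer-style cache: blocks on dim 0, K/V interleaved on dim 1.
layer_fi = torch.zeros(num_blocks, 2, page_size, num_heads, head_dim)
recv_fi = torch.randn(len(block_ids), 2, page_size, num_heads, head_dim)
layer_fi[block_ids, ...] = recv_fi        # index along the first dimension

# FlashAttention-style cache: K/V split on dim 0, blocks on dim 1.
layer_fa = torch.zeros(2, num_blocks, page_size, num_heads, head_dim)
recv_fa = torch.randn(2, len(block_ids), page_size, num_heads, head_dim)
layer_fa[:, block_ids, ...] = recv_fa     # index along the second dimension
```

If fewer blocks arrive than expected, the diff above falls back to `block_ids[:num_block]`, updating only the overlapping prefix and logging a warning instead of failing.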
@@ -201,19 +188,17 @@ def inject_kv_into_layer(
                 if kv_cache is None:
                     continue

-                kv_cache_layer = kv_cache[ \
-                    forward_context.virtual_engine]
+                layer = kv_cache[forward_context.virtual_engine]

                 kv_cache = self.p2p_nccl_engine.recv_tensor(
                     request.request_id + "#" + layer_name)

                 if kv_cache is None:
-                    logger.warning("🚧src_kv_cache is None, %s",
-                                   request.request_id)
+                    logger.warning("🚧kv_cache is None, %s", request.request_id)
                     continue

-                inject_kv_into_layer(kv_cache_layer, kv_cache,
-                                     request.slot_mapping, request.request_id)
+                inject_kv_into_layer(layer, kv_cache, request.block_ids,
+                                     request.request_id)

     def wait_for_layer_load(self, layer_name: str) -> None:
         """Blocking until the KV for a specific layer is loaded into vLLM's
@@ -247,20 +232,33 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,

         def extract_kv_from_layer(
             layer: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            block_ids: torch.Tensor,
         ) -> torch.Tensor:
-            """Extract the KV cache from the layer.
+            """
+            Extract KV cache slices from a given attention layer tensor.
+
+            This function handles multiple backend layouts:
+            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
+              indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            Args:
+                layer (torch.Tensor): The KV cache from the attention layer.
+                block_ids (torch.Tensor): Indices of blocks to extract.

-            Assume the shape of the layer is (2, num_pages, page_size, xxx)
-            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
+            Returns:
+                torch.Tensor: A tensor containing the extracted KV slices.
+                Returns None if the layout is unsupported.
             """
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages, page_size = layer.shape[0], layer.shape[1]
-                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
-                                                                ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                return layer[block_ids, ...]
+
+            if layer.shape[0] == 2:  # FlashAttention
+                return layer[:, block_ids, ...]
+
+            return None

         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
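Extraction mirrors injection: blocks are sliced out whole instead of being gathered token by token through a flattened `(num_pages * page_size, ...)` view. Below is a standalone sketch of the same dispatch under assumed toy shapes; the `MLACommonMetadata` check is dropped here because no attention metadata exists outside the connector:

```python
# Illustration only: block-wise extraction for the two supported layouts,
# simplified to shape checks (the real code also checks MLACommonMetadata).
from typing import Optional
import torch

def extract_blocks(layer: torch.Tensor,
                   block_ids: torch.Tensor) -> Optional[torch.Tensor]:
    if layer.shape[1] == 2:            # MLA / FlashInfer-style: blocks on dim 0
        return layer[block_ids, ...]
    if layer.shape[0] == 2:            # FlashAttention-style: blocks on dim 1
        return layer[:, block_ids, ...]
    return None                        # unsupported layout

layer = torch.randn(2, 8, 16, 4, 32)   # FlashAttention-style toy cache
out = extract_blocks(layer, torch.tensor([0, 3]))
assert out.shape == (2, 2, 16, 4, 32)  # K/V split kept, two blocks extracted
```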
@@ -269,7 +267,7 @@ def extract_kv_from_layer(
             ip, port = self.parse_request_id(request_id, True)
             remote_address = ip + ":" + str(port + self._rank)

-            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
+            kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
             self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                              kv_cache, remote_address)

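On the send side, the destination address is derived from the request id plus the sender's rank, and the tensor key matches the one used by `recv_tensor` on the receiving side. A small sketch of that addressing scheme, with invented ip, port, and rank values:

```python
# Illustration only: how the send side builds the remote address and the
# tensor key. The ip, port, and rank values here are made up.
request_id = "cmpl-123"
layer_name = "model.layers.0.self_attn.attn"

ip, base_port = "10.0.0.2", 21001   # stand-ins for what parse_request_id returns
rank = 1                            # stand-in for the connector's rank (self._rank)

remote_address = ip + ":" + str(base_port + rank)   # one port per rank
tensor_key = request_id + "#" + layer_name          # matches the recv_tensor key
print(remote_address, tensor_key)
```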