class ReqMeta:
    # Request Id
    request_id: str
-    # Request tokens
-    token_ids: torch.Tensor
-    # Slot mappings, should have the same length as token_ids
-    slot_mapping: torch.Tensor
+    # Request block ids
+    block_ids: torch.Tensor
+    # Request num tokens
+    num_tokens: int

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
-        valid_num_tokens = len(token_ids)
-        token_ids_tensor = torch.tensor(token_ids)
        block_ids_tensor = torch.tensor(block_ids)
-        num_blocks = block_ids_tensor.shape[0]
-        block_offsets = torch.arange(0, block_size)
-        slot_mapping = block_offsets.reshape((1, block_size)) + \
-            block_ids_tensor.reshape((num_blocks, 1)) * block_size
-        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
-
        return ReqMeta(
            request_id=request_id,
-            token_ids=token_ids_tensor,
-            slot_mapping=slot_mapping,
+            block_ids=block_ids_tensor,
+            num_tokens=len(token_ids),
        )
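For reference, a minimal sketch (not part of the diff) of what the removed per-token slot_mapping computed and what the slimmed-down ReqMeta now carries; the token values, block ids, and block size below are made up for illustration.

import torch

block_size = 4
token_ids = [11, 22, 33, 44, 55, 66]   # 6 tokens
block_ids = [7, 9]                     # blocks that hold those tokens

# Old metadata: one slot index per token, derived from the block ids.
block_ids_tensor = torch.tensor(block_ids)
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
    block_ids_tensor.reshape((-1, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:len(token_ids)]
print(slot_mapping)  # tensor([28, 29, 30, 31, 36, 37])

# New metadata: only the block ids and the token count travel with the
# request; per-token slots are no longer materialized.
# meta = ReqMeta.make_meta("req-0", token_ids, block_ids, block_size)
# meta.block_ids -> tensor([7, 9]); meta.num_tokens -> 6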
@@ -123,63 +115,58 @@ def start_load_kv(self, forward_context: "ForwardContext",
            return

        def inject_kv_into_layer(
-            dst_kv_cache_layer: torch.Tensor,
-            src_kv_cache: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            layer: torch.Tensor,
+            kv_cache: torch.Tensor,
+            block_ids: torch.Tensor,
            request_id: str,
        ) -> None:
-            """Inject the KV cache into the layer.
+            """
+            Inject KV cache data into a given attention layer tensor.
+
+            This function updates `layer` in-place with values from `kv_cache`,
+            handling different backend layouts:
+            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
+              indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            If the number of provided block IDs does not match the number of KV
+            blocks, only the overlapping portion is updated, and a warning is
+            logged.

            Args:
-                dst_kv_cache_layer (torch.Tensor): the destination KV cache
-                    layer. In shape [2, num_pages, page_size, xxx] if not
-                    using MLA, [num_pages, page_size, xxx] otherwise.
-                src_kv_cache (torch.Tensor): the source KV cache. In shape
-                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
-                    otherwise.
-                slot_mapping (torch.Tensor): the slot mapping. In shape
-                    [num_tokens].
-                request_id (str): request id for log
+                layer (torch.Tensor): The attention layer KV tensor to update.
+                kv_cache (torch.Tensor): The KV cache tensor to inject.
+                block_ids (torch.Tensor): Indices of the blocks to update.
+                request_id (str): Request identifier used for logging.
+
+            Returns:
+                None. The function modifies `layer` in-place.
            """
-            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages = dst_kv_cache_layer_shape[0]
-                page_size = dst_kv_cache_layer_shape[1]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              0)
-                num_token = src_kv_cache.shape[0]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                num_block = kv_cache.shape[0]
+                self.check_tensors_except_dim(layer, kv_cache, 0)
+                if len(block_ids) == num_block:
+                    layer[block_ids, ...] = kv_cache
                else:
-                    dst_kv_cache_layer[slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[block_ids[:num_block], ...] = kv_cache
                    logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
-            else:
-                num_pages = dst_kv_cache_layer_shape[1]
-                page_size = dst_kv_cache_layer_shape[2]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    2, num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              1)
-                num_token = src_kv_cache.shape[1]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
+
+            elif layer.shape[0] == 2:  # FlashAttention
+                num_block = kv_cache.shape[1]
+                self.check_tensors_except_dim(layer, kv_cache, 1)
+                if len(block_ids) == num_block:
+                    layer[:, block_ids, ...] = kv_cache
                else:
-                    dst_kv_cache_layer[:, slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[:, block_ids[:num_block], ...] = kv_cache
                    logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)

        # Get the metadata
        metadata: KVConnectorMetadata = \
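To make the block-indexed write concrete, here is a small standalone sketch; the shapes (8 blocks, block size 16, 128 hidden values per token) are arbitrary, and only the indexing pattern mirrors inject_kv_into_layer above.

import torch

# FlashAttention-style layer: [2, num_blocks, block_size, hidden]
layer = torch.zeros(2, 8, 16, 128)
kv_cache = torch.randn(2, 3, 16, 128)   # KV received for 3 blocks
block_ids = torch.tensor([1, 4, 6])     # local blocks assigned to the request

# dim 0 == 2 -> FlashAttention layout: write along the second dimension.
layer[:, block_ids, ...] = kv_cache

# For MLA/FlashInfer-style layouts, blocks sit on the first dimension
# (e.g. [num_blocks, 2, block_size, ...] for FlashInfer), so the same
# write becomes: layer[block_ids, ...] = kv_cache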
@@ -201,19 +188,17 @@ def inject_kv_into_layer(
            if kv_cache is None:
                continue

-            kv_cache_layer = kv_cache[ \
-                forward_context.virtual_engine]
+            layer = kv_cache[forward_context.virtual_engine]

            kv_cache = self.p2p_nccl_engine.recv_tensor(
                request.request_id + "#" + layer_name)

            if kv_cache is None:
-                logger.warning("🚧src_kv_cache is None, %s",
-                               request.request_id)
+                logger.warning("🚧kv_cache is None, %s", request.request_id)
                continue

-            inject_kv_into_layer(kv_cache_layer, kv_cache,
-                                 request.slot_mapping, request.request_id)
+            inject_kv_into_layer(layer, kv_cache, request.block_ids,
+                                 request.request_id)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
@@ -247,20 +232,33 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,

        def extract_kv_from_layer(
            layer: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            block_ids: torch.Tensor,
        ) -> torch.Tensor:
-            """Extract the KV cache from the layer.
+            """
+            Extract KV cache slices from a given attention layer tensor.
+
+            This function handles multiple backend layouts:
+            - MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
+              indexed along the first dimension.
+            - FlashAttention: KV tensors are indexed along the second
+              dimension.
+
+            Args:
+                layer (torch.Tensor): The KV cache from the attention layer.
+                block_ids (torch.Tensor): Indices of blocks to extract.

-            Assume the shape of the layer is (2, num_pages, page_size, xxx)
-            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
+            Returns:
+                torch.Tensor: A tensor containing the extracted KV slices.
+                Returns None if the layout is unsupported.
            """
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages, page_size = layer.shape[0], layer.shape[1]
-                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
-                                                                ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                return layer[block_ids, ...]
+
+            if layer.shape[0] == 2:  # FlashAttention
+                return layer[:, block_ids, ...]
+
+            return None

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
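The save path is the mirror image of injection; continuing the same toy shapes as before, extraction amounts to a gather along the block dimension.

import torch

layer = torch.randn(2, 8, 16, 128)      # FlashAttention layout
block_ids = torch.tensor([1, 4, 6])

kv_cache = layer[:, block_ids, ...]     # shape [2, 3, 16, 128]

# MLA/FlashInfer layouts gather along the first dimension instead:
# kv_cache = layer[block_ids, ...]
# Any other layout falls through and extract_kv_from_layer returns None.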
@@ -269,7 +267,7 @@ def extract_kv_from_layer(
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)

-            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
+            kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
            self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                             kv_cache, remote_address)
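Finally, a sketch of the two string conventions visible in this hunk; the request id, layer name, IP, port, and rank are placeholders: each transferred tensor is keyed by request id plus layer name, and the destination address offsets the parsed port by the sender's rank.

request_id = "cmpl-0"                          # placeholder request id
layer_name = "model.layers.0.self_attn.attn"   # placeholder layer name
ip, port, rank = "10.0.0.2", 21001, 1          # placeholder parse_request_id result

tensor_key = request_id + "#" + layer_name
remote_address = ip + ":" + str(port + rank)
# e.g. p2p_nccl_engine.send_tensor(tensor_key, kv_cache, remote_address)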