
Commit 460d02a

[NIXL] Fix after virtual block_size for host_buffer with heter kv_layout (#29122)
Signed-off-by: Chendi Xue <[email protected]>
Parent: b4c8fba

2 files changed (+13, -9 lines)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 13 additions & 1 deletion

@@ -1042,10 +1042,12 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
         NOT directly supported by NIXL (e.g., tpu)
         """
         xfer_buffers: dict[str, torch.Tensor] = {}
+        inv_order = [0, 1, 3, 2, 4]
         try:
             for layer_name, kv_cache in kv_caches.items():
                 kv_shape = kv_cache.shape
                 kv_dtype = kv_cache.dtype
+                permute_shape = False
                 if (
                     self.kv_cache_layout == "NHD"
                     and self.vllm_config.kv_transfer_config is not None
@@ -1059,10 +1061,20 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
                     # Since NHD will not support Decode/Prefill TP_ratio > 1,
                     # we can leverage host_buffer for permute
                     self.host_buffer_kv_cache_layout = "HND"
-                    kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
+                    kv_shape = (
+                        tuple(kv_shape[i] for i in inv_order)
+                        if not self.use_mla
+                        else kv_shape
+                    )
+                    permute_shape = not self.use_mla
+
                 xfer_buffers[layer_name] = torch.empty(
                     kv_shape, dtype=kv_dtype, device="cpu"
                 )
+                if permute_shape:
+                    xfer_buffers[layer_name] = xfer_buffers[layer_name].permute(
+                        inv_order
+                    )
         except MemoryError as e:
             logger.error("NIXLConnectorWorker gets %s.", e)
             raise
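
What the connector-side change does, in isolation: when the device cache layout is NHD but the host buffer must hold HND for heterogeneous kv_layout support, the buffer is now allocated with HND-ordered storage and then permuted back, so the resulting view is indexed exactly like the NHD device cache; MLA caches skip the swap since there is no separate head axis to reorder. Below is a minimal, self-contained sketch of that allocate-then-permute pattern. The shape, dtype, and dimension names are illustrative assumptions, not values taken from vLLM:

import torch

# Hypothetical NHD-shaped device KV cache:
# (kv, num_blocks, block_size, num_heads, head_dim)
nhd_shape = (2, 4, 16, 8, 128)
inv_order = [0, 1, 3, 2, 4]  # swaps the block_size and num_heads axes

# Allocate the host storage in HND order, then permute the view back to NHD.
hnd_shape = tuple(nhd_shape[i] for i in inv_order)
host_buffer = torch.empty(hnd_shape, dtype=torch.float16, device="cpu").permute(inv_order)

# The view is indexed like the NHD device cache ...
assert host_buffer.shape == nhd_shape
# ... while the underlying memory stays laid out as HND (non-contiguous view).
assert not host_buffer.is_contiguous()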

vllm/platforms/xpu.py

Lines changed: 0 additions & 8 deletions

@@ -251,10 +251,6 @@ def insert_blocks_to_device(
     ) -> None:
         """Copy blocks from src_cache to dst_cache on XPU."""
         _src_cache = src_cache[:, src_block_indices]
-        if _src_cache.shape[2:] != dst_cache.shape[2:]:
-            # To support TP_ratio, HOST KV might be initiated with HND
-            # while XPU device KV is with NHD
-            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

     @classmethod
@@ -267,8 +263,4 @@ def swap_out_blocks_to_host(
     ) -> None:
         """Copy blocks from XPU to host (CPU)."""
         _src_cache = src_cache[:, src_block_indices]
-        if _src_cache.shape[2:] != dst_cache.shape[2:]:
-            # XPU device KV is with NHD while HOST KV
-            # might be initiated with HND for TP_ratio support
-            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.cpu()
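
The xpu.py deletions follow directly from the connector change: because the host buffer view now presents the same NHD shape as the XPU device cache, the shape check and extra permute in the copy helpers are dead code and a plain indexed copy suffices. A rough illustration of why, reusing the hypothetical shapes from the sketch above (the block indices are made up, and the "device" tensor lives on CPU purely for demonstration):

import torch

nhd_shape = (2, 4, 16, 8, 128)
inv_order = [0, 1, 3, 2, 4]

# Device cache in NHD; host buffer stored HND but viewed as NHD after permute.
device_cache = torch.randn(nhd_shape)
host_buffer = torch.empty(
    tuple(nhd_shape[i] for i in inv_order), dtype=device_cache.dtype
).permute(inv_order)

# swap_out_blocks_to_host-style copy: src and dst shapes already agree,
# so no layout check or permute is needed in the copy path.
src_block_indices = [1, 3]
dst_block_indices = [0, 2]
host_buffer[:, dst_block_indices] = device_cache[:, src_block_indices]

torch.testing.assert_close(
    host_buffer[:, dst_block_indices], device_cache[:, src_block_indices]
)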
