
Commit 72f431e

[Nixl] Minor refactor to handshake related metadata (#26410)
Signed-off-by: NickLucche <[email protected]>
1 parent be44450 commit 72f431e

2 files changed: +176 -88 lines changed


tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 0 additions & 2 deletions
@@ -565,8 +565,6 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
             kv_cache_layout=mismatched_layout,
         )
 
-        # We don't check layout for homogeneous TP and MLA for now, as the
-        # whole block is moved.
         with pytest.raises(RuntimeError):
             # mismatched layout is expected to fail
             worker.add_remote_agent(meta, remote_tp_size=2)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 176 additions & 86 deletions
@@ -36,7 +36,6 @@
     get_tensor_model_parallel_world_size,
     get_tp_group,
 )
-from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -521,6 +520,72 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
+    @dataclass
+    class TpKVTopology:
+        """
+        Helper class for tensor parallel and KV topology information for
+        mapping between local and remote TP workers.
+        """
+
+        tp_size: int
+        tp_rank: int
+        remote_tp_size: dict[EngineId, int]
+        is_mla: bool
+        total_num_kv_heads: int
+
+        def tp_ratio(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Calculate the tensor parallel ratio between local and remote TP.
+            We can think of it as the number of local TP workers-per-remote TP
+            workers. Local workers will read from the same remote TP worker in
+            groups of size `tp_ratio`.
+            """
+            assert self.tp_size % remote_tp_size == 0, (
+                f"Local tensor parallel size {self.tp_size} is not divisible "
+                f"by remote tensor parallel size {remote_tp_size}."
+            )
+            return self.tp_size // remote_tp_size
+
+        def tp_ratio_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.tp_ratio(remote_tp_size)
+
+        def is_kv_replicated(self, engine_id: EngineId) -> bool:
+            """
+            Whether the KV cache is replicated across TP workers due to the
+            number of TP workers being greater than the number of KV heads.
+            """
+            tp_size = self.remote_tp_size[engine_id]
+            return tp_size // self.total_num_kv_heads >= 1
+
+        def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+            # MLA is always replicated as the hidden dim can't be split.
+            return self.is_mla or self.is_kv_replicated(remote_engine_id)
+
+        def get_target_remote_rank(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Get the remote TP rank (on P) that the current local TP rank
+            (on D) will read from.
+            """
+            tp_ratio = self.tp_ratio(remote_tp_size)
+            return self.tp_rank // tp_ratio
+
+        def get_target_remote_rank_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.get_target_remote_rank(remote_tp_size)
+
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
@@ -534,6 +599,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
+        self.kv_transfer_config = vllm_config.kv_transfer_config
 
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
@@ -654,7 +720,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()
 
-        self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -686,6 +751,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()
 
+        self.kv_topo = self.TpKVTopology(
+            tp_size=self.world_size,
+            tp_rank=self.tp_rank,
+            remote_tp_size=self._tp_size,  # shared state
+            is_mla=self.use_mla,
+            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+        )
+
     @staticmethod
     def _nixl_handshake_listener(
         metadata: NixlAgentMetadata,
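The "# shared state" comment matters: remote_tp_size=self._tp_size hands the helper a reference to the worker's dict, not a copy, so engines registered later in add_remote_agent become visible through kv_topo without extra wiring. A toy illustration of that aliasing, with illustrative names rather than vLLM APIs:

# Toy illustration of the shared-state comment: the helper holds a reference
# to the same dict the worker mutates, not a copy.
remote_tp_size: dict[str, int] = {}  # stands in for self._tp_size


class Helper:
    def __init__(self, remote_tp_size: dict[str, int]):
        self.remote_tp_size = remote_tp_size  # shared reference, no copy


helper = Helper(remote_tp_size)
remote_tp_size["prefill-engine-0"] = 2  # later handshake registers a remote
print(helper.remote_tp_size)            # {'prefill-engine-0': 2}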
@@ -731,8 +804,7 @@ def _nixl_handshake(
 
         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-        p_remote_rank = self.tp_rank // tp_ratio
+        p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug(
             "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
@@ -989,13 +1061,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
 
         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
-        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+        if self.model_config.hf_config.model_type == "llama4":
             from transformers import Llama4TextConfig
 
-            assert isinstance(
-                self.vllm_config.model_config.hf_text_config, Llama4TextConfig
-            )
-            llama4_config = self.vllm_config.model_config.hf_text_config
+            assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
+            llama4_config = self.model_config.hf_text_config
             no_rope_layers = llama4_config.no_rope_layers
             chunk_size = llama4_config.attention_chunk_size
             chunk_block_size = math.ceil(chunk_size / self.block_size)
@@ -1078,36 +1148,106 @@ def add_remote_agent(
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            logger.debug(
+                "Remote agent with engine_id %s and rank"
+                "%s already exchanged metadata, skip handshake.",
+                engine_id,
+                remote_tp_rank,
+            )
             return self._remote_agents[engine_id][remote_tp_rank]
 
+        ### Register remote agent metadata
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
-        else:
-            assert self._tp_size[engine_id] == remote_tp_size
-        # TODO We may eventually want to skip enforcing the same attn backend.
-        assert nixl_agent_meta.attn_backend_name == self.backend_name
 
         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata
         )
 
+        # Handle tp_size>num_kv_heads: replicate KV cache.
+        replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks
+        # so we only register once per engine_id.
+        if engine_id not in self.dst_num_blocks:
+            self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
+
+        # Keep track of remote agent kv caches base addresses.
+        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+
+        self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
+
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
-        tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
+
+        ### Register remote agent memory regions
+        blocks_data = []
+        # With homogeneous TP, D pulls the whole kv cache from corresponding
+        # rank. With heterogeneous TP, prepare the descriptors by splitting the
+        # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
+        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+
+        # Register all remote blocks, but only the corresponding kv heads.
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = (
+                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
+            )
+            for block_id in range(nixl_agent_meta.num_blocks):
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
+                # For each block, grab the heads chunk belonging to rank_i
+                # of size remote_nheads // tp_ratio, which correspond to
+                # self.block_len == remote_block_len//tp_ratio bytes.
+                addr = base_addr + block_offset + rank_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # With FlashInfer index V separately to allow head splitting.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
+                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+
+        logger.debug(
+            "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
+            len(blocks_data),
+            engine_id,
+            remote_tp_rank,
+            self.tp_rank,
+        )
+
+        # Register with NIXL.
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+        self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+            remote_agent_name, descs
+        )
+
+        return remote_agent_name
+
+    def _validate_remote_agent_handshake(
+        self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
+    ):
+        """
+        Validate the remote agent handshake metadata ensuring the
+        invariants hold true.
+        """
+        remote_engine_id = nixl_agent_meta.engine_id
+
+        assert self._tp_size[remote_engine_id] == remote_tp_size
+        # TODO We may eventually want to skip enforcing the same attn backend.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
         assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-
-        # Handle tp_size>num_kv_heads: replicate KV cache.
-        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
-        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
-
-        remote_block_len = nixl_agent_meta.block_lens[0]
-        if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
             if (
-                self.vllm_config.kv_transfer_config is not None
-                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+                self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"
             ):
                 logger.info(
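For readers checking the descriptor math: a hedged sketch with made-up sizes shows how rank_offset and block_offset pick one head-chunk per remote block. Assuming, purely for illustration, a remote block of 4096 bytes, a local kv_block_len of 2048 (so tp_ratio = 2), 3 remote blocks and decode rank 1, the registered addresses land on the second half of every block:

# Hedged sketch of the rank_offset/block_offset arithmetic from the loop above,
# with made-up sizes; it mirrors the address math only, not the NIXL registration.
num_blocks = 3
remote_block_len = 4096      # bytes per block on the P worker (illustrative)
kv_block_len = 2048          # bytes this D worker reads per block (= remote / tp_ratio)
tp_ratio = remote_block_len // kv_block_len
base_addr = 0x10000

tp_rank = 1                  # local decode rank
replicates_kv_cache = False  # heterogeneous TP, heads are split

rank_offset = tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
blocks_data = []
for block_id in range(num_blocks):
    block_offset = block_id * remote_block_len
    addr = base_addr + block_offset + rank_offset
    blocks_data.append((hex(addr), kv_block_len))

print(blocks_data)
# [('0x10800', 2048), ('0x11800', 2048), ('0x12800', 2048)]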
@@ -1121,13 +1261,19 @@ def add_remote_agent(
                     "Or enable experimental feature to use HND to NHD support by "
                     "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
                 )
-        if self.use_mla or is_kv_replicated:
+
+        # Block len can only vary across layers when using MLA.
+        remote_block_len = nixl_agent_meta.block_lens[0]
+        if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
             # With replicated KV cache, only the number of blocks can differ.
             assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
                 "KV cache sizes must match between P and D when replicated"
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
+            if tp_ratio > 1 and self.device_type == "xpu":
+                # XPU uses NHD, hence it does not support splitting on H
+                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
@@ -1139,14 +1285,6 @@ def add_remote_agent(
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
-            if tp_ratio > 1:
-                # Heterogeneous TP expects same kv_cache_layout.
-                if nixl_agent_meta.kv_cache_layout == "NHD":
-                    raise ValueError(
-                        "Heterogeneous TP is not supported for remote with NHD."
-                    )
-                if self.device_type == "xpu":
-                    raise ValueError("Heterogeneous TP is not supported on XPU")
 
         assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
@@ -1158,60 +1296,10 @@ def add_remote_agent(
             f"{self.block_size=}, {remote_block_size=}"
         )
 
-        # Create dst descs and xfer side handles. TP workers have same #blocks.
-        if engine_id in self.dst_num_blocks:
-            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
-        else:
-            self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
-
-        blocks_data = []
-        # With homogeneous TP, D pulls the whole kv cache from corresponding
-        # rank. With heterogeneous TP, prepare the descriptors by splitting the
-        # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
-        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+        # TP workers have same #blocks.
+        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
 
         assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
-        # Register all remote blocks, but only the corresponding kv heads.
-        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
-            rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len
-                if not (self.use_mla or is_kv_replicated)
-                else 0
-            )
-            for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_lens[i]
-                # For each block, grab the heads chunk belonging to rank_i
-                # of size remote_nheads // tp_ratio, which correspond to
-                # self.block_len == remote_block_len//tp_ratio bytes.
-                addr = base_addr + block_offset + rank_offset
-                # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, remote_tp_rank))
-
-            if self._use_flashinfer:
-                # With FlashInfer index V separately to allow head splitting.
-                for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_lens[i]
-                    addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
-                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
-
-        logger.debug(
-            "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
-            len(blocks_data),
-            engine_id,
-            remote_tp_rank,
-            self.tp_rank,
-        )
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
-        self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-            remote_agent_name, descs
-        )
-
-        return remote_agent_name
 
     def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         """copy recved kv from host buffer to device."""
@@ -1505,14 +1593,16 @@ def _read_blocks(
 
         # Number of D TP workers that will read from dst P. Propagate tp_ratio
         # on notification so that dst worker can wait before freeing blocks.
-        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
         notif_id = f"{request_id}:{tp_ratio}".encode()
 
         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            remote_rank = self.tp_rank // tp_ratio
+            remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
+                dst_engine_id
+            )
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
             try:
                 self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
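Since tp_ratio decode workers pull from the same prefill worker, the ratio is baked into the notification id so the P side can count arrivals before freeing blocks. The sketch below mirrors the f"{request_id}:{tp_ratio}" encoding from the diff; the P-side counter is a hypothetical illustration of the idea, not the exact vLLM bookkeeping:

# Illustrative sketch of the "request_id:tp_ratio" notification protocol.
from collections import defaultdict


def make_notif_id(request_id: str, tp_ratio: int) -> bytes:
    # Same encoding as in the diff: the ratio rides along with the request id.
    return f"{request_id}:{tp_ratio}".encode()


# Hypothetical P-side counter: free the request's blocks only once all
# tp_ratio decode readers have notified.
notif_counts: dict[str, int] = defaultdict(int)


def on_notif(notif_id: bytes) -> bool:
    request_id, tp_ratio = notif_id.decode().rsplit(":", 1)
    notif_counts[request_id] += 1
    return notif_counts[request_id] >= int(tp_ratio)  # True => safe to free


nid = make_notif_id("req-123", tp_ratio=2)
print(on_notif(nid), on_notif(nid))  # False True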
