     get_tensor_model_parallel_world_size,
     get_tp_group,
 )
-from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -521,6 +520,72 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""

+    @dataclass
+    class TpKVTopology:
+        """
+        Helper class for tensor parallel and KV topology information for
+        mapping between local and remote TP workers.
+        """
+
+        tp_size: int
+        tp_rank: int
+        remote_tp_size: dict[EngineId, int]
+        is_mla: bool
+        total_num_kv_heads: int
+
+        def tp_ratio(
+            self,
+            remote_tp_size: int,
+        ) -> int:
540+ """
541+ Calculate the tensor parallel ratio between local and remote TP.
542+ We can think of it as the number of local TP workers-per-remote TP
543+ workers. Local workers will read from the same remote TP worker in
544+ groups of size `tp_ratio`.
545+ """
+            assert self.tp_size % remote_tp_size == 0, (
+                f"Local tensor parallel size {self.tp_size} is not divisible "
+                f"by remote tensor parallel size {remote_tp_size}."
+            )
+            return self.tp_size // remote_tp_size
+
+        def tp_ratio_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.tp_ratio(remote_tp_size)
+
+        def is_kv_replicated(self, engine_id: EngineId) -> bool:
+            """
+            Whether the KV cache is replicated across TP workers due to the
+            number of TP workers being greater than the number of KV heads.
+            """
+            tp_size = self.remote_tp_size[engine_id]
+            return tp_size // self.total_num_kv_heads >= 1
+
+        def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+            # MLA is always replicated as the hidden dim can't be split.
+            return self.is_mla or self.is_kv_replicated(remote_engine_id)
+
+        def get_target_remote_rank(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Get the remote TP rank (on P) that the current local TP rank
+            (on D) will read from.
+            """
+            tp_ratio = self.tp_ratio(remote_tp_size)
+            return self.tp_rank // tp_ratio
+
+        def get_target_remote_rank_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.get_target_remote_rank(remote_tp_size)
+
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
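
The mapping that `TpKVTopology` encodes above is plain integer arithmetic. A small standalone sketch with hypothetical sizes (decode TP=4 pulling from prefill TP=2) shows which prefill rank each decode rank reads from; with homogeneous TP the ratio is 1 and rank i simply reads from rank i:

    # Hypothetical sizes for illustration only (not taken from the change above).
    local_tp_size, remote_tp_size = 4, 2
    assert local_tp_size % remote_tp_size == 0
    tp_ratio = local_tp_size // remote_tp_size   # 2 decode workers per prefill worker
    for local_rank in range(local_tp_size):
        remote_rank = local_rank // tp_ratio     # D ranks 0,1 -> P rank 0; D ranks 2,3 -> P rank 1
        print(f"D rank {local_rank} pulls KV from P rank {remote_rank}")
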
@@ -534,6 +599,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
+        self.kv_transfer_config = vllm_config.kv_transfer_config

         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
@@ -654,7 +720,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()

-        self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -686,6 +751,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()

+        self.kv_topo = self.TpKVTopology(
+            tp_size=self.world_size,
+            tp_rank=self.tp_rank,
+            remote_tp_size=self._tp_size,  # shared state
+            is_mla=self.use_mla,
+            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+        )
+
     @staticmethod
     def _nixl_handshake_listener(
         metadata: NixlAgentMetadata,
@@ -731,8 +804,7 @@ def _nixl_handshake(

         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-        p_remote_rank = self.tp_rank // tp_ratio
+        p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug(
             "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
@@ -989,13 +1061,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):

         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
-        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+        if self.model_config.hf_config.model_type == "llama4":
             from transformers import Llama4TextConfig

-            assert isinstance(
-                self.vllm_config.model_config.hf_text_config, Llama4TextConfig
-            )
-            llama4_config = self.vllm_config.model_config.hf_text_config
+            assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
+            llama4_config = self.model_config.hf_text_config
             no_rope_layers = llama4_config.no_rope_layers
             chunk_size = llama4_config.attention_chunk_size
             chunk_block_size = math.ceil(chunk_size / self.block_size)
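
The chunk-to-block conversion in the last line is plain ceiling division; with illustrative numbers (an attention chunk of 8192 tokens and a 16-token KV block), each local-attention chunk spans 512 blocks:

    import math

    # Illustrative values only; the real ones come from the Llama 4 text config.
    chunk_size, block_size = 8192, 16
    chunk_block_size = math.ceil(chunk_size / block_size)  # 512 KV blocks per attention chunk
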
@@ -1078,36 +1148,106 @@ def add_remote_agent(
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            logger.debug(
+                "Remote agent with engine_id %s and rank "
+                "%s already exchanged metadata, skip handshake.",
+                engine_id,
+                remote_tp_rank,
+            )
             return self._remote_agents[engine_id][remote_tp_rank]

+        ### Register remote agent metadata
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
-        else:
-            assert self._tp_size[engine_id] == remote_tp_size
-        # TODO We may eventually want to skip enforcing the same attn backend.
-        assert nixl_agent_meta.attn_backend_name == self.backend_name

         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata
         )

+        # Handle tp_size>num_kv_heads: replicate KV cache.
+        replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks
+        # so we only register once per engine_id.
+        if engine_id not in self.dst_num_blocks:
+            self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
+
+        # Keep track of remote agent kv caches base addresses.
+        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+
+        self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
+
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
-        tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
+
+        ### Register remote agent memory regions
+        blocks_data = []
+        # With homogeneous TP, D pulls the whole kv cache from corresponding
+        # rank. With heterogeneous TP, prepare the descriptors by splitting the
+        # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
+        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+
+        # Register all remote blocks, but only the corresponding kv heads.
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = (
+                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
+            )
+            for block_id in range(nixl_agent_meta.num_blocks):
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
+                # For each block, grab the heads chunk belonging to rank_i
+                # of size remote_nheads // tp_ratio, which correspond to
+                # self.block_len == remote_block_len//tp_ratio bytes.
+                addr = base_addr + block_offset + rank_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # With FlashInfer index V separately to allow head splitting.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
+                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+
+        logger.debug(
+            "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
+            len(blocks_data),
+            engine_id,
+            remote_tp_rank,
+            self.tp_rank,
+        )
+
+        # Register with NIXL.
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+        self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+            remote_agent_name, descs
+        )
+
+        return remote_agent_name
+
+    def _validate_remote_agent_handshake(
+        self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
+    ):
+        """
+        Validate the remote agent handshake metadata ensuring the
+        invariants hold true.
+        """
+        remote_engine_id = nixl_agent_meta.engine_id
+
+        assert self._tp_size[remote_engine_id] == remote_tp_size
+        # TODO We may eventually want to skip enforcing the same attn backend.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
         assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-
-        # Handle tp_size>num_kv_heads: replicate KV cache.
-        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
-        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
-
-        remote_block_len = nixl_agent_meta.block_lens[0]
-        if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
             if (
-                self.vllm_config.kv_transfer_config is not None
-                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+                self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"
             ):
                 logger.info(
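
The descriptor construction above reduces to offset arithmetic over each P worker's flat KV region: every local rank registers the same blocks, shifted by its head-chunk offset within each block. A hedged sketch with hypothetical sizes (one layer, three remote blocks of 4096 bytes, tp_ratio=2, so each D rank owns a 2048-byte half per block):

    # Hypothetical layout for illustration only.
    base_addr, num_blocks = 0x7F0000000000, 3
    remote_block_len, tp_ratio = 4096, 2
    kv_block_len = remote_block_len // tp_ratio          # bytes owned by one D rank per block

    for tp_rank in (0, 1):                               # the two D ranks sharing this P rank
        rank_offset = tp_rank % tp_ratio * kv_block_len  # 0 for rank 0, 2048 for rank 1
        for block_id in range(num_blocks):
            block_offset = block_id * remote_block_len
            addr = base_addr + block_offset + rank_offset
            # Rank 0 registers bytes [0, 2048) of each block, rank 1 bytes [2048, 4096).
            print(hex(addr), kv_block_len)
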
@@ -1121,13 +1261,19 @@ def add_remote_agent(
11211261 "Or enable experimental feature to use HND to NHD support by "
11221262 "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
11231263 )
1124- if self .use_mla or is_kv_replicated :
1264+
1265+ # Block len can only vary across layers when using MLA.
1266+ remote_block_len = nixl_agent_meta .block_lens [0 ]
1267+ if self .use_mla or self .kv_topo .is_kv_replicated (remote_engine_id ):
11251268 # With replicated KV cache, only the number of blocks can differ.
11261269 assert self .block_len_per_layer == nixl_agent_meta .block_lens , (
11271270 "KV cache sizes must match between P and D when replicated"
11281271 )
11291272 remote_block_size = remote_block_len // (self .slot_size_per_layer [0 ])
11301273 else :
1274+ if tp_ratio > 1 and self .device_type == "xpu" :
1275+ # XPU uses NHD, hence it does not support splitting on H
1276+ raise ValueError ("Heterogeneous TP is not supported on XPU" )
11311277 # When MLA is not used, this is a list of the same block length
11321278 for block_len in nixl_agent_meta .block_lens :
11331279 assert block_len == remote_block_len , (
@@ -1139,14 +1285,6 @@ def add_remote_agent(
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
-            if tp_ratio > 1:
-                # Heterogeneous TP expects same kv_cache_layout.
-                if nixl_agent_meta.kv_cache_layout == "NHD":
-                    raise ValueError(
-                        "Heterogeneous TP is not supported for remote with NHD."
-                    )
-                if self.device_type == "xpu":
-                    raise ValueError("Heterogeneous TP is not supported on XPU")

             assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
@@ -1158,60 +1296,10 @@ def add_remote_agent(
11581296 f"{ self .block_size = } , { remote_block_size = } "
11591297 )
11601298
1161- # Create dst descs and xfer side handles. TP workers have same #blocks.
1162- if engine_id in self .dst_num_blocks :
1163- assert self .dst_num_blocks [engine_id ] == nixl_agent_meta .num_blocks
1164- else :
1165- self .dst_num_blocks [engine_id ] = nixl_agent_meta .num_blocks
1166-
1167- blocks_data = []
1168- # With homogeneous TP, D pulls the whole kv cache from corresponding
1169- # rank. With heterogeneous TP, prepare the descriptors by splitting the
1170- # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
1171- # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
1172- self .kv_caches_base_addr [engine_id ] = nixl_agent_meta .kv_caches_base_addr
1299+ # TP workers have same #blocks.
1300+ assert self .dst_num_blocks [remote_engine_id ] == nixl_agent_meta .num_blocks
11731301
11741302 assert len (nixl_agent_meta .kv_caches_base_addr ) == len (self .block_len_per_layer )
1175- # Register all remote blocks, but only the corresponding kv heads.
1176- for i , base_addr in enumerate (nixl_agent_meta .kv_caches_base_addr ):
1177- kv_block_len = self .get_backend_aware_kv_block_len (layer_idx = i )
1178- rank_offset = (
1179- self .tp_rank % tp_ratio * kv_block_len
1180- if not (self .use_mla or is_kv_replicated )
1181- else 0
1182- )
1183- for block_id in range (nixl_agent_meta .num_blocks ):
1184- block_offset = block_id * nixl_agent_meta .block_lens [i ]
1185- # For each block, grab the heads chunk belonging to rank_i
1186- # of size remote_nheads // tp_ratio, which correspond to
1187- # self.block_len == remote_block_len//tp_ratio bytes.
1188- addr = base_addr + block_offset + rank_offset
1189- # (addr, len, device id)
1190- blocks_data .append ((addr , kv_block_len , remote_tp_rank ))
1191-
1192- if self ._use_flashinfer :
1193- # With FlashInfer index V separately to allow head splitting.
1194- for block_id in range (nixl_agent_meta .num_blocks ):
1195- block_offset = block_id * nixl_agent_meta .block_lens [i ]
1196- addr = base_addr + block_offset + rank_offset
1197- v_addr = addr + nixl_agent_meta .block_lens [i ] // 2
1198- blocks_data .append ((v_addr , kv_block_len , remote_tp_rank ))
1199-
1200- logger .debug (
1201- "Created %s blocks for dst engine %s with remote rank %s and local rank %s" ,
1202- len (blocks_data ),
1203- engine_id ,
1204- remote_tp_rank ,
1205- self .tp_rank ,
1206- )
1207-
1208- # Register with NIXL.
1209- descs = self .nixl_wrapper .get_xfer_descs (blocks_data , self .nixl_memory_type )
1210- self .dst_xfer_side_handles [engine_id ] = self .nixl_wrapper .prep_xfer_dlist (
1211- remote_agent_name , descs
1212- )
1213-
1214- return remote_agent_name
12151303
12161304 def sync_recved_kv_to_device (self , req_id : str , meta : ReqMeta ):
12171305 """copy recved kv from host buffer to device."""
@@ -1505,14 +1593,16 @@ def _read_blocks(

         # Number of D TP workers that will read from dst P. Propagate tp_ratio
         # on notification so that dst worker can wait before freeing blocks.
-        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
         notif_id = f"{request_id}:{tp_ratio}".encode()

         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            remote_rank = self.tp_rank // tp_ratio
+            remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
+                dst_engine_id
+            )
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
             try:
                 self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
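
The notification payload is just the request id with the reader count appended, so the P worker can keep the blocks alive until every D rank in its group has checked in. A hedged sketch of producing and consuming that payload (the P-side bookkeeping below is illustrative, not the connector's exact code):

    from collections import defaultdict

    # D side: advertise how many D ranks will read from this P worker.
    request_id, tp_ratio = "req-123", 2
    notif_id = f"{request_id}:{tp_ratio}".encode()

    # P side (illustrative): count notifications per request and only release
    # the blocks once all expected readers have reported in.
    notification_counts: defaultdict[str, int] = defaultdict(int)

    def on_notif(payload: bytes) -> bool:
        req_id, expected = payload.decode().rsplit(":", 1)
        notification_counts[req_id] += 1
        return notification_counts[req_id] >= int(expected)  # True => safe to free blocks

    assert on_notif(notif_id) is False  # first of two readers
    assert on_notif(notif_id) is True   # second reader arrives; blocks can be freed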