@@ -30,6 +30,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
 
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
     num_blocks: int
     block_len: int
     attn_backend_name: str
+    kv_cache_layout: str
 
 
 @dataclass
@@ -538,7 +540,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         attn_backend = backend_name_to_enum(self.backend_name)
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
+        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -839,7 +843,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             block_len=self.block_len,
-            attn_backend_name=self.backend_name)
+            attn_backend_name=self.backend_name,
+            kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ def add_remote_agent(self,
             self._tp_size[engine_id] = remote_tp_size
         else:
             assert self._tp_size[engine_id] == remote_tp_size
-        # We may eventually enable this after asserting equality in cache
-        # layout and close outputs.
+        # TODO We may eventually want to skip enforcing the same attn backend.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
 
         remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@ def add_remote_agent(self,
             if self._use_flashinfer:
                 # Account for joint KV in FlashInfer.
                 remote_block_size //= 2
+            if tp_ratio > 1:
+                # Heterogeneous TP expects same kv_cache_layout.
+                assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
 
             assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
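
To summarize the check this diff introduces: each worker records its own KV cache layout at startup (via get_kv_cache_layout()), advertises it in NixlAgentMetadata during the NIXL handshake, and add_remote_agent asserts that the remote layout matches the local one whenever tp_ratio > 1, alongside the existing attention-backend and block-length checks. The following standalone sketch only illustrates that rule; HandshakeMeta and validate_remote_meta are hypothetical stand-ins for NixlAgentMetadata and the checks inside add_remote_agent, not vLLM's actual API, and the layout strings and backend name below are example values.

# Minimal sketch of the handshake-time compatibility check, under the
# assumptions stated above (hypothetical names, not vLLM's real API).
from dataclasses import dataclass


@dataclass
class HandshakeMeta:
    # Subset of the fields exchanged during the NIXL handshake.
    num_blocks: int
    block_len: int
    attn_backend_name: str
    kv_cache_layout: str  # example values: "NHD" or "HND"


def validate_remote_meta(remote: HandshakeMeta, local: HandshakeMeta,
                         tp_ratio: int) -> None:
    # The attention backend must always match (the diff keeps this assert,
    # with a TODO about possibly relaxing it later).
    assert remote.attn_backend_name == local.attn_backend_name
    # Heterogeneous TP expects the same KV cache layout.
    if tp_ratio > 1:
        assert remote.kv_cache_layout == local.kv_cache_layout
    # Remote P worker block length must be tp_ratio times the local one.
    assert remote.block_len == local.block_len * tp_ratio


if __name__ == "__main__":
    local = HandshakeMeta(1024, 4096, "FLASH_ATTN_VLLM_V1", "NHD")
    remote = HandshakeMeta(2048, 8192, "FLASH_ATTN_VLLM_V1", "NHD")
    validate_remote_meta(remote, local, tp_ratio=2)  # passes

Because the new assertion is gated on tp_ratio > 1, homogeneous-TP deployments are unaffected; only the heterogeneous-TP path requires matching layouts between the prefill and decode workers.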