# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import threading

import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlAgentMetadata, NixlConnectorWorker)
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm_gaudi.extension.logger import logger as init_logger

logger = init_logger()

nixl_connector._NIXL_SUPPORTED_XPUS = {
    "cuda": ("cuda", ),
    "tpu": ("cpu", ),
    "hpu": ("cpu", ),
}
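# NOTE: _NIXL_SUPPORTED_XPUS maps each device type to the kv_buffer_device
# values the NIXL connector accepts for it. Adding "hpu": ("cpu", ) makes the
# connector stage Gaudi KV transfers through a CPU host buffer, mirroring the
# existing "tpu" entry; the patched helpers below assume this mapping.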


def initialize_host_xfer_buffer(
        self, kv_caches: dict[str, torch.Tensor]) -> None:
    """
    Initialize a transfer buffer in CPU memory for accelerators that are
    NOT directly supported by NIXL (e.g., tpu, hpu).
    """
    xfer_buffers: dict[str, torch.Tensor] = {}
    try:
        for layer_name, kv_cache in kv_caches.items():
            if self.device_type == "hpu":
                # HPU caches are (k, v) pairs; mirror them as a single
                # stacked host tensor of shape (2, *kv_cache[0].shape).
                kv_shape = (2, *kv_cache[0].shape)
                kv_dtype = kv_cache[0].dtype
            else:
                kv_shape = kv_cache.shape
                kv_dtype = kv_cache.dtype
            xfer_buffers[layer_name] = torch.empty(kv_shape,
                                                   dtype=kv_dtype,
                                                   device="cpu")
    except MemoryError as e:
        logger.error(
            "NixlConnectorWorker failed to allocate host transfer "
            "buffers: %s", e)
        raise

    self.host_xfer_buffers = xfer_buffers

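# Illustrative host-buffer shape (assumed values, not from a real config): if
# a per-layer HPU cache is a (k, v) pair of tensors, each shaped
# [num_blocks * block_size, kv_heads, head_dim] = [131072, 8, 128], the
# buffer allocated above is a cpu tensor of shape (2, 131072, 8, 128) with
# the device cache's dtype.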

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV Cache data in nixl."""

    _, first_kv_cache = next(iter(kv_caches.items()))
    if self.device_type == "hpu":
        kv_elem_size = first_kv_cache[0][0].dtype.itemsize
    else:
        kv_elem_size = first_kv_cache.element_size()

    if self.use_host_buffer:
        self.initialize_host_xfer_buffer(kv_caches=kv_caches)
        assert len(self.host_xfer_buffers) == len(kv_caches), (
            f"host_buffer: {len(self.host_xfer_buffers)}, "
            f"kv_caches: {len(kv_caches)}")
        xfer_buffers = self.host_xfer_buffers
    else:
        xfer_buffers = kv_caches
        assert not self.host_xfer_buffers, (
            "host_xfer_buffer should not be initialized when "
            f"kv_buffer_device is {self.kv_buffer_device}")

    # TODO(tms): Find a more robust way to detect and handle MLA
    # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
    # KV memory layout is HND, as opposed to the default NHD. Note that this
    # only affects the strides. For MLA we require no such layout and resort
    # to the standard one.
    use_mla = (len(first_kv_cache.shape) == 3
               if self.device_type != "hpu" else False)
    if self.device_type == "tpu":
        assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
        assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
        # tpu (v1) kv shape per layer:
        # (num_blocks, block_size, num_kv_heads * 2, head_size)
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # [block_size, kv_heads, head_dim]
        block_shape = first_kv_cache.shape[-block_rank:]
        block_size, n_kv_heads_x_2, head_dim = block_shape
        self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
    elif self.device_type == "cuda":
        assert use_mla == self.use_mla
        # TODO (NickLucche) not compatible with hybrid allocator.
        # Enforce check once it goes live, as a single kv layout
        # is expected for xfers.
        if use_mla:
            # MLA case.
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 2  # [block_size, latent_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, kv_latent_dim = block_shape
            self.slot_size_bytes = kv_elem_size * kv_latent_dim
        else:
            # [2 (k and v), num_blocks, ...]
            if self._use_flashinfer:
                # FlashInfer swaps 2<->num_blocks dimensions.
                self.num_blocks = first_kv_cache.shape[0]
                block_rank = 4  # [2, block_size, kv_heads, head_dim]
            else:
                self.num_blocks = first_kv_cache.shape[1]
                block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, n_kv_heads, head_dim = block_shape[-3:]
            # head size in bytes.
            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
        assert block_size == self.block_size
    elif self.device_type == "hpu":
        # habana kv_cache per layer: a (k, v) pair of tensors, each shaped
        # [num_blocks * block_size, kv_heads, head_dim]
        self.num_blocks = first_kv_cache[0].shape[0] // self.block_size
        block_rank = 3  # [block_size, kv_heads, head_dim]
        block_shape = first_kv_cache[0].shape[-block_rank:]
        # The leading dim is flattened to num_blocks * block_size; divide it
        # back down so block_shape describes a single block.
        block_shape = torch.Size((block_shape[0] // self.num_blocks,
                                  *block_shape[1:]))
        block_size, n_kv_heads, head_dim = block_shape
        # head size in bytes.
        self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
    else:
        raise RuntimeError(
            f"{self.device_type} ({self.backend_name}) is not supported.")

    # TODO(tms): self.block_len needs to be per-layer for sliding window,
    # hybrid attn, etc
    # block size in bytes
    self.block_len = kv_elem_size * math.prod(block_shape)
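    # Worked example with hypothetical values: a bf16 cache (kv_elem_size=2),
    # 8 KV heads, head_dim=128 and block_size=128 give
    # slot_size_bytes = 2 * 8 * 128 = 2048 and
    # block_len = 2 * (128 * 8 * 128) = 262144 bytes per block in each region.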
    logger.info(
        "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
        "use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
        "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
        self.use_host_buffer, self.num_blocks, block_shape,
        first_kv_cache[0].shape)
    self.dst_num_blocks[self.engine_id] = self.num_blocks
    self.device_kv_caches = kv_caches
    kv_caches_base_addr = []
    caches_data = []

    # Note(tms): I modified this from the original region setup code.
    # K and V are now in different regions. Advantage is that we can
    # elegantly support MLA and any cases where the K and V tensors
    # are non-contiguous (it's not locally guaranteed that they will be)
    # Disadvantage is that the encoded NixlAgentMetadata is now larger
    # (roughly 8KB vs 5KB).
    # Conversely for FlashInfer, K and V are transferred in the same tensor
    # to better exploit the memory layout (ie num_blocks is the first dim).
    for cache_or_caches in xfer_buffers.values():
        # Normalize to always be a list of caches
        cache_list = ([cache_or_caches] if use_mla or self._use_pallas_v1
                      or self._use_flashinfer else cache_or_caches)
        for cache in cache_list:
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len
            # NOTE: use tp_rank for device_id since multi-node TP
            # is rarely used.
            caches_data.append((base_addr, region_len, self.tp_rank, ""))
            kv_caches_base_addr.append(base_addr)
    self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
    self.num_regions = len(caches_data)
    self.num_layers = len(xfer_buffers)
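    # NOTE: with the per-tensor registration above, non-MLA / non-FlashInfer
    # layouts (including the HPU (k, v) pairs) register K and V as separate
    # regions, so num_regions is expected to be 2 * num_layers, while MLA,
    # pallas and FlashInfer caches contribute one region per layer.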

    # TODO(mgoin): remove this once we have hybrid memory allocator
    # Optimization for models with local attention (Llama 4)
    if self.vllm_config.model_config.hf_config.model_type == "llama4":
        from transformers import Llama4TextConfig
        assert isinstance(self.vllm_config.model_config.hf_text_config,
                          Llama4TextConfig)
        llama4_config = self.vllm_config.model_config.hf_text_config
        no_rope_layers = llama4_config.no_rope_layers
        chunk_size = llama4_config.attention_chunk_size
        chunk_block_size = math.ceil(chunk_size / self.block_size)
        for layer_idx in range(self.num_layers):
            # no_rope_layers[layer_idx] == 0 means NoPE (global)
            # Any other value means RoPE (local chunked)
            is_local_attention = no_rope_layers[layer_idx] != 0
            block_window = chunk_block_size if is_local_attention else None
            self.block_window_per_layer.append(block_window)
        logger.debug("Llama 4 block window per layer mapping: %s",
                     self.block_window_per_layer)
        assert len(self.block_window_per_layer) == self.num_layers
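        # Example with hypothetical numbers: attention_chunk_size = 8192 and
        # block_size = 128 give chunk_block_size = ceil(8192 / 128) = 64, so
        # each RoPE (local chunked) layer gets a 64-block window and NoPE
        # (global) layers keep block_window = None.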

    descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                            self.nixl_memory_type)
    logger.debug("Registering descs: %s", caches_data)
    self.nixl_wrapper.register_memory(descs)
    logger.debug("Done registering descs")
    self._registered_descs.append(descs)

    # Register local/src descr for NIXL xfer.
    blocks_data = []
    for base_addr in self.kv_caches_base_addr[self.engine_id]:
        # NOTE With heter-TP, more blocks are prepared than what are
        # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
        # could create fewer, but then _get_block_descs_ids needs to
        # select agent_meta.num_blocks instead of self.num_blocks for
        # local descr, and that makes handling regular flow less clean.
        for block_id in range(self.num_blocks):
            block_offset = block_id * self.block_len
            addr = base_addr + block_offset
            # (addr, len, device id)
            # TODO: does device_id matter to DRAM?
            blocks_data.append((addr, self.block_len, self.tp_rank))
    logger.debug("Created %s blocks for src engine %s and rank %s",
                 len(blocks_data), self.engine_id, self.tp_rank)

    descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                             self.nixl_memory_type)
    # NIXL_INIT_AGENT to be used for preparations of local descs.
    self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
        "NIXL_INIT_AGENT", descs)

    # After the KV caches are registered, listen for new connections.
    metadata = NixlAgentMetadata(
        engine_id=self.engine_id,
        agent_metadata=self.nixl_wrapper.get_agent_metadata(),
        kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
        num_blocks=self.num_blocks,
        block_len=self.block_len,
        attn_backend_name=self.backend_name,
        kv_cache_layout=self.kv_cache_layout)
    ready_event = threading.Event()
    self._nixl_handshake_listener_t = threading.Thread(
        target=self._nixl_handshake_listener,
        args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
        daemon=True,
        name="nixl_handshake_listener")
    self._nixl_handshake_listener_t.start()
    ready_event.wait()  # Wait for listener ZMQ socket to be ready.

nixl_connector.NixlConnectorWorker.initialize_host_xfer_buffer = (
    initialize_host_xfer_buffer)
nixl_connector.NixlConnectorWorker.register_kv_caches = register_kv_caches

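
# A minimal sanity-check sketch (added for illustration, not part of the
# patch itself): when this module is executed directly, confirm that the
# monkey-patches above took effect on the upstream NixlConnectorWorker class.
if __name__ == "__main__":
    worker_cls = nixl_connector.NixlConnectorWorker
    assert worker_cls.register_kv_caches is register_kv_caches
    assert (worker_cls.initialize_host_xfer_buffer
            is initialize_host_xfer_buffer)
    assert nixl_connector._NIXL_SUPPORTED_XPUS["hpu"] == ("cpu", )
    print("HPU NIXL connector patches applied")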