
Commit 9322571

enable nixl connector for hpu
1 parent a18e09b commit 9322571

File tree

9 files changed: +304 additions, −16 deletions


vllm_gaudi/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ def register():
 def register_ops():
     """Register custom ops for the HPU platform."""
     import vllm_gaudi.ops  # noqa: F401
+    import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hpu_nixl_connector  # noqa: F401
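The added import is there purely for its side effects: when hpu_nixl_connector is imported, it rebinds methods on the upstream connector classes (see the patch modules later in this diff), so a bare import inside register_ops() is all the wiring needed. A minimal, self-contained sketch of that import-for-side-effect pattern, with stand-in names that are not from the commit:

# Illustration only: BaseConnector / hpu_register_kv_caches are stand-ins.
class BaseConnector:
    def register_kv_caches(self, kv_caches):
        return "generic path"


def hpu_register_kv_caches(self, kv_caches):
    # HPU-aware replacement, analogous to hpu_nixl_connector's overrides.
    return "hpu path"


# In the real plugin this assignment runs at import time of the patch module.
BaseConnector.register_kv_caches = hpu_register_kv_caches
assert BaseConnector().register_kv_caches({}) == "hpu path"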

vllm_gaudi/distributed/kv_transfer/__init__.py

Whitespace-only changes.

vllm_gaudi/distributed/kv_transfer/kv_connector/__init__.py

Whitespace-only changes.

vllm_gaudi/distributed/kv_transfer/kv_connector/v1/__init__.py

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, CopyBlocksOp)
from vllm_gaudi.extension.logger import logger as init_logger

logger = init_logger()


class KVTransferParams:
    """
    Abstract KVTransferParams used to send KVTransfer
    parameters between instances of vLLM.
    Specific instances of KVConnector customize this
    method for serializing / deserializing msgs sent
    via the HTTP protocol.
    """

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str,
                                    Any]]) -> Optional["KVTransferParams"]:
        return None


def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
    """
    Set the xPU-specific ops for copying KV between host and device.
    Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
    """
    return


# ==============================
# Scheduler-side methods
# ==============================

def set_kv_transfer_params(self, request: "Request"):
    _KVTransferParams = KVTransferParams
    """Parse raw KV Transfer params."""
    assert request.kv_transfer_params is None
    kv_transfer_params = self._KVTransferParams.from_raw_dict(
        request.raw_kv_transfer_params)
    request.kv_transfer_params = kv_transfer_params


KVConnectorBase_V1.set_host_xfer_buffer_ops = set_host_xfer_buffer_ops
KVConnectorBase_V1.set_kv_transfer_params = set_kv_transfer_params
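These two functions are attached to KVConnectorBase_V1 at import time. set_kv_transfer_params expects the connector class to expose a _KVTransferParams type whose from_raw_dict() turns request.raw_kv_transfer_params into a typed object. A hedged sketch of what such a params type could look like for a concrete connector; the DisaggParams name and its fields are invented for illustration and are not part of this commit:

# Illustration only: a made-up params type showing the shape of the
# from_raw_dict() hook that set_kv_transfer_params relies on.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DisaggParams:
    remote_engine_id: Optional[str] = None
    remote_block_ids: Optional[list[int]] = None

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str, Any]]) -> Optional["DisaggParams"]:
        if not raw_dict:
            return None
        return DisaggParams(
            remote_engine_id=raw_dict.get("remote_engine_id"),
            remote_block_ids=raw_dict.get("remote_block_ids"),
        )


# A connector class would expose this type (e.g. as its _KVTransferParams
# attribute) so that self._KVTransferParams.from_raw_dict(...) resolves in
# set_kv_transfer_params above.
params = DisaggParams.from_raw_dict({"remote_engine_id": "prefill-0",
                                     "remote_block_ids": [0, 1, 2]})
assert params is not None and params.remote_engine_id == "prefill-0"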
Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import threading
import torch
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm_gaudi.extension.logger import logger as init_logger
from vllm.distributed.kv_transfer.kv_connector.v1 import (
    nixl_connector)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnectorWorker, NixlAgentMetadata)

logger = init_logger()

nixl_connector._NIXL_SUPPORTED_XPUS = {
    "cuda": ("cuda", ),
    "tpu": ("cpu", ),
    "hpu": ("cpu", )
}


def initialize_host_xfer_buffer(
        self, kv_caches: dict[str, torch.Tensor]) -> None:
    """
    Initialize transfer buffer in CPU mem for accelerators
    NOT directly supported by NIXL (e.g., tpu)
    """
    xfer_buffers: dict[str, torch.Tensor] = {}
    try:
        for layer_name, kv_cache in kv_caches.items():
            if self.device_type == "hpu":
                kv_shape = (2, *kv_cache[0].shape)
                kv_dtype = kv_cache[0].dtype
                xfer_buffers[layer_name] = torch.empty(kv_shape,
                                                       dtype=kv_dtype,
                                                       device="cpu")
            else:
                kv_shape = kv_cache.shape
                kv_dtype = kv_cache.dtype
                xfer_buffers[layer_name] = torch.empty(kv_shape,
                                                       dtype=kv_dtype,
                                                       device="cpu")
    except MemoryError as e:
        logger.error("NIXLConnectorWorker gets %s.", e)
        raise

    self.host_xfer_buffers = xfer_buffers


def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV Cache data in nixl."""

    _, first_kv_cache = next(iter(kv_caches.items()))
    if self.device_type == "hpu":
        kv_elem_size = first_kv_cache[0][0].dtype.itemsize
    else:
        kv_elem_size = first_kv_cache.element_size()

    if self.use_host_buffer:
        self.initialize_host_xfer_buffer(kv_caches=kv_caches)
        assert len(self.host_xfer_buffers) == len(kv_caches), (
            f"host_buffer: {len(self.host_xfer_buffers)}, "
            f"kv_caches: {len(kv_caches)}")
        xfer_buffers = self.host_xfer_buffers
    else:
        xfer_buffers = kv_caches
        assert not self.host_xfer_buffers, (
            "host_xfer_buffer should not be initialized when "
            f"kv_buffer_device is {self.kv_buffer_device}")

    # TODO(tms): Find a more robust way to detect and handle MLA
    # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
    # KV memory layout is HND, as opposed to the default NHD. Note that it
    # only affects the strides. For MLA instead, we require no such thing
    # and resort to the standard layout.
    use_mla = len(first_kv_cache.shape) == 3 if self.device_type != "hpu" else False
    if self.device_type == "tpu":
        assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
        assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
        # tpu (v1) kv shape per layer:
        # (num_blocks, block_size, num_kv_heads * 2, head_size)
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # [block_size, kv_heads, head_dim]
        block_shape = first_kv_cache.shape[-block_rank:]
        block_size, n_kv_heads_x_2, head_dim = block_shape
        self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
    elif self.device_type == "cuda":
        assert use_mla == self.use_mla
        # TODO (NickLucche) not compatible with hybrid allocator.
        # Enforce check once it goes live, as a single kv layout
        # is expected for xfers.
        if use_mla:
            # MLA case.
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 2  # [block_size, latent_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, kv_latent_dim = block_shape
            self.slot_size_bytes = kv_elem_size * kv_latent_dim
        else:
            # [2 (k and v), num_blocks, ...]
            if self._use_flashinfer:
                # FlashInfer swaps 2<->num_blocks dimensions.
                self.num_blocks = first_kv_cache.shape[0]
                block_rank = 4  # [2, block_size, kv_heads, head_dim]
            else:
                self.num_blocks = first_kv_cache.shape[1]
                block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, n_kv_heads, head_dim = block_shape[-3:]
            # head size in bytes.
            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
        assert block_size == self.block_size
    elif self.device_type == "hpu":
        # habana kv_cache: [2, num_blocks*block_size, kv_heads, head_dim]
        #from remote_pdb import RemotePdb; RemotePdb('0.0.0.0', 4444).set_trace()
        self.num_blocks = first_kv_cache[0].shape[0] // self.block_size
        block_rank = 3  # [block_size, kv_heads, head_dim]
        block_shape = first_kv_cache[0].shape[-block_rank:]
        block_shape = list(block_shape)
        block_shape[0] = block_shape[0] // self.num_blocks
        block_shape = torch.Size(block_shape)
        block_size, n_kv_heads, head_dim = block_shape[-3:]
        # head size in bytes.
        self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
    else:
        raise RuntimeError(
            f"{self.device_type} ({self.backend_name}) is not supported.")

    # TODO(tms): self.block_len needs to be per-layer for sliding window,
    # hybrid attn, etc
    # block size in bytes
    self.block_len = kv_elem_size * math.prod(block_shape)
    logger.info(
        "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
        "use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
        "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
        self.use_host_buffer, self.num_blocks, block_shape,
        first_kv_cache[0].shape)
    self.dst_num_blocks[self.engine_id] = self.num_blocks
    self.device_kv_caches = kv_caches
    kv_caches_base_addr = []
    caches_data = []

    # Note(tms): I modified this from the original region setup code.
    # K and V are now in different regions. Advantage is that we can
    # elegantly support MLA and any cases where the K and V tensors
    # are non-contiguous (it's not locally guaranteed that they will be)
    # Disadvantage is that the encoded NixlAgentMetadata is now larger
    # (roughly 8KB vs 5KB).
    # Conversely for FlashInfer, K and V are transferred in the same tensor
    # to better exploit the memory layout (ie num_blocks is the first dim).
    for cache_or_caches in xfer_buffers.values():
        # Normalize to always be a list of caches
        cache_list = [cache_or_caches] if use_mla \
            or self._use_pallas_v1 or self._use_flashinfer \
            else cache_or_caches
        for cache in cache_list:
            base_addr = cache.data_ptr()
            region_len = self.num_blocks * self.block_len
            # NOTE: use tp_rank for device_id since multi-node TP
            # is rarely used.
            caches_data.append((base_addr, region_len, self.tp_rank, ""))
            kv_caches_base_addr.append(base_addr)
    self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
    self.num_regions = len(caches_data)
    self.num_layers = len(xfer_buffers.keys())

    # TODO(mgoin): remove this once we have hybrid memory allocator
    # Optimization for models with local attention (Llama 4)
    if self.vllm_config.model_config.hf_config.model_type == "llama4":
        from transformers import Llama4TextConfig
        assert isinstance(self.vllm_config.model_config.hf_text_config,
                          Llama4TextConfig)
        llama4_config = self.vllm_config.model_config.hf_text_config
        no_rope_layers = llama4_config.no_rope_layers
        chunk_size = llama4_config.attention_chunk_size
        chunk_block_size = math.ceil(chunk_size / self.block_size)
        for layer_idx in range(self.num_layers):
            # no_rope_layers[layer_idx] == 0 means NoPE (global)
            # Any other value means RoPE (local chunked)
            is_local_attention = no_rope_layers[layer_idx] != 0
            block_window = chunk_block_size if is_local_attention else None
            self.block_window_per_layer.append(block_window)
        logger.debug("Llama 4 block window per layer mapping: %s",
                     self.block_window_per_layer)
        assert len(self.block_window_per_layer) == self.num_layers

    descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                            self.nixl_memory_type)
    logger.debug("Registering descs: %s", caches_data)
    self.nixl_wrapper.register_memory(descs)
    logger.debug("Done registering descs")
    self._registered_descs.append(descs)

    # Register local/src descr for NIXL xfer.
    blocks_data = []
    for base_addr in self.kv_caches_base_addr[self.engine_id]:
        # NOTE With heter-TP, more blocks are prepared than what are
        # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
        # could create fewer, but then _get_block_descs_ids needs to
        # select agent_meta.num_blocks instead of self.num_blocks for
        # local descr, and that makes handling regular flow less clean.
        for block_id in range(self.num_blocks):
            block_offset = block_id * self.block_len
            addr = base_addr + block_offset
            # (addr, len, device id)
            # TODO: does device_id matter to DRAM?
            blocks_data.append((addr, self.block_len, self.tp_rank))
    logger.debug("Created %s blocks for src engine %s and rank %s",
                 len(blocks_data), self.engine_id, self.tp_rank)

    descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
                                             self.nixl_memory_type)
    # NIXL_INIT_AGENT to be used for preparations of local descs.
    self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
        "NIXL_INIT_AGENT", descs)

    # After KV Caches registered, listen for new connections.
    metadata = NixlAgentMetadata(
        engine_id=self.engine_id,
        agent_metadata=self.nixl_wrapper.get_agent_metadata(),
        kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
        num_blocks=self.num_blocks,
        block_len=self.block_len,
        attn_backend_name=self.backend_name,
        kv_cache_layout=self.kv_cache_layout)
    ready_event = threading.Event()
    self._nixl_handshake_listener_t = threading.Thread(
        target=self._nixl_handshake_listener,
        args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
        daemon=True,
        name="nixl_handshake_listener")
    self._nixl_handshake_listener_t.start()
    ready_event.wait()  # Wait for listener ZMQ socket to be ready.


nixl_connector.NixlConnectorWorker.initialize_host_xfer_buffer = initialize_host_xfer_buffer
nixl_connector.NixlConnectorWorker.register_kv_caches = register_kv_caches

vllm_gaudi/platform.py

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 128
+            #vllm_config.kv_transfer_config.kv_buffer_device = 'hpu'
         if (parallel_config.distributed_executor_backend in ['mp', 'uni']
                 and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
             if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD",

@@ -213,3 +214,4 @@ def _synced_weight_loader(param, *args, **kwargs):
 logger.warning(msg)
 import vllm.model_executor.utils as utils
 utils.set_weight_attrs = set_weight_attrs
+
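Since _NIXL_SUPPORTED_XPUS maps "hpu" to a CPU staging buffer, and the commented-out line above hints at kv_buffer_device, a plausible way to drive this path is through vLLM's KVTransferConfig with a CPU kv_buffer_device. This is a hedged sketch only; the model name and kv_role below are placeholders, not taken from this commit:

# Sketch: driving the patched NixlConnector on Gaudi from Python, assuming
# vLLM's public KVTransferConfig / LLM APIs.
from vllm import LLM
from vllm.config import KVTransferConfig

kv_cfg = KVTransferConfig(
    kv_connector="NixlConnector",   # the upstream connector this plugin patches
    kv_role="kv_both",              # placeholder role for a single-process smoke test
    kv_buffer_device="cpu",         # HPU stages KV in host memory per _NIXL_SUPPORTED_XPUS
)

# Placeholder model; anything supported on Gaudi would do.
llm = LLM(model="facebook/opt-125m", kv_transfer_config=kv_cfg)
print(llm.generate(["Hello, Gaudi!"]))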
