
Commit c3e2978

[NIXL] fix cpu PD after physical <> logical block_size PR (#28904)
Signed-off-by: Chendi Xue <[email protected]>
1 parent e4bb268 · commit c3e2978

3 files changed: +17 -5 lines

tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

Lines changed: 7 additions & 2 deletions

@@ -55,7 +55,7 @@ DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
 
-SMI_BIN=$(which nvidia-smi || which rocm-smi)
+SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
 
 # Trap the SIGINT signal (triggered by Ctrl+C)
 trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

@@ -91,8 +91,13 @@ get_model_args() {
 get_num_gpus() {
   if [[ "$SMI_BIN" == *"nvidia"* ]]; then
     echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)"
-  else
+  elif [[ "$SMI_BIN" == *"rocm"* ]]; then
     echo "$($SMI_BIN -l | grep GPU | wc -l)"
+  else
+    # Works for non-CUDA platforms:
+    # assume at least 1 device and
+    # let the system decide which card to use
+    echo "1"
   fi
 }
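
For reference, a minimal Python sketch of the same fallback logic (an
illustration only, not part of the commit): prefer nvidia-smi, then rocm-smi,
and otherwise report a single device so non-CUDA platforms can still run the
accuracy test. The helper name mirrors the shell function but is hypothetical.

    import shutil
    import subprocess

    def get_num_gpus() -> int:
        # Prefer nvidia-smi, then rocm-smi; empty string if neither exists.
        smi = shutil.which("nvidia-smi") or shutil.which("rocm-smi") or ""
        if "nvidia" in smi:
            out = subprocess.run(
                [smi, "--query-gpu=name", "--format=csv,noheader"],
                capture_output=True, text=True,
            ).stdout
            return len(out.splitlines())
        if "rocm" in smi:
            out = subprocess.run([smi, "-l"], capture_output=True, text=True).stdout
            return sum("GPU" in line for line in out.splitlines())
        # No SMI tool found: assume at least one device and let the
        # system decide which card to use.
        return 1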

tools/install_nixl_from_source_ubuntu.py

Lines changed: 1 addition & 0 deletions

@@ -95,6 +95,7 @@ def install_system_dependencies():
         "meson",
         "libtool",
         "libtool-bin",
+        "pkg-config",
     ]
     run_command(["apt-get", "update"])
     run_command(["apt-get", "install", "-y"] + apt_packages)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 9 additions & 3 deletions

@@ -1161,6 +1161,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # to better exploit the memory layout (ie num_blocks is the first dim).
         split_k_and_v = self.kv_topo.split_k_and_v
         tensor_size_bytes = None
+
+        # TODO (NickLucche): Get kernel_block_size in a cleaner way
+        # NHD default "view" for non-MLA cache
+        if self.device_type == "cpu":
+            block_size_position = -2
+        else:
+            block_size_position = -2 if self.use_mla else -3
+
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms

@@ -1175,9 +1183,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if base_addr in seen_base_addresses:
                 continue
 
-            # TODO (NickLucche): Get kernel_block_size in a cleaner way
-            # NHD default "view" for non-MLA cache
-            kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
+            kernel_block_size = cache.shape[block_size_position]
 
             if self.block_size != kernel_block_size:
                 logger.info_once(
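
To make the connector change above concrete, here is a minimal sketch (not
vLLM code) of the new dimension selection: the block-size axis of a registered
KV-cache tensor is now chosen from device type and MLA usage instead of being
hard-coded per layout. The cache shapes below are illustrative assumptions,
not the exact layouts of any particular backend.

    import torch

    def block_size_dim(device_type: str, use_mla: bool) -> int:
        # Mirrors the block_size_position logic added in register_kv_caches().
        if device_type == "cpu":
            return -2
        return -2 if use_mla else -3

    # Hypothetical non-MLA GPU cache in NHD order:
    # (num_blocks, block_size, num_heads, head_size)
    gpu_cache = torch.empty(10, 16, 8, 64)
    assert gpu_cache.shape[block_size_dim("cuda", use_mla=False)] == 16

    # Hypothetical MLA cache: (num_blocks, block_size, head_size)
    mla_cache = torch.empty(10, 16, 576)
    assert mla_cache.shape[block_size_dim("cuda", use_mla=True)] == 16

    # Hypothetical CPU cache with block_size second-to-last
    cpu_cache = torch.empty(10, 16, 512)
    assert cpu_cache.shape[block_size_dim("cpu", use_mla=False)] == 16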
