Skip to content

Commit f413523

Browse files
authored
feat(distributed): add get_required_kvcache_layout class method to kv connector api (#20433)
Signed-off-by: wxsm <[email protected]>
1 parent 4904e53 commit f413523

File tree

7 files changed

+186
-28
lines changed

7 files changed

+186
-28
lines changed

tests/distributed/test_kvlayout.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
                         VllmConfig, set_current_vllm_config)
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)
from vllm.logger import init_logger

# Name the logger after this module ("test_kvlayout"); the previous
# "test_expert_parallel" name was a copy-paste leftover from another test.
logger = init_logger("test_kvlayout")
11+
12+
13+
def test_get_kv_connector_cache_layout_without_kv_connector():
    """With no KV connector configured, the layout falls back to NHD."""
    config = VllmConfig(device_config=DeviceConfig("cpu"))
    with set_current_vllm_config(config):
        # Default settings: no kv_transfer_config present.
        assert get_kv_connector_cache_layout() == "NHD"
19+
20+
21+
def test_get_kv_connector_cache_layout_with_lmcache_connector():
    """A connector with no layout requirement keeps the NHD default."""
    transfer_cfg = KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
    )
    config = VllmConfig(
        device_config=DeviceConfig("cpu"),
        kv_transfer_config=transfer_cfg,
    )
    with set_current_vllm_config(config):
        # LMCacheConnectorV1 does not mandate a layout.
        assert get_kv_connector_cache_layout() == "NHD"
32+
33+
34+
def test_get_kv_connector_cache_layout_with_nixl_connector():
    """NixlConnector requests the HND layout (non-MLA model)."""
    transfer_cfg = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    config = VllmConfig(
        device_config=DeviceConfig("cpu"),
        model_config=ModelConfig(),
        kv_transfer_config=transfer_cfg,
    )
    with set_current_vllm_config(config):
        # Nixl prefers HND for faster transfers.
        assert get_kv_connector_cache_layout() == "HND"
47+
48+
49+
def test_get_kv_connector_cache_layout_with_multi_connector():
    """MultiConnector adopts the layout required by its children (HND here)."""
    children = [
        {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
        {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
    ]
    transfer_cfg = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": children},
    )
    config = VllmConfig(
        device_config=DeviceConfig("cpu"),
        model_config=ModelConfig(),
        kv_transfer_config=transfer_cfg,
    )
    with set_current_vllm_config(config):
        # Only NixlConnector expresses a preference, so HND wins.
        assert get_kv_connector_cache_layout() == "HND"

vllm/distributed/kv_transfer/kv_connector/base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from abc import ABC, abstractmethod
12-
from typing import TYPE_CHECKING, Union
12+
from typing import TYPE_CHECKING, Optional, Union
1313

1414
import torch
1515

@@ -124,5 +124,19 @@ def recv_kv_caches_and_hidden_states(
124124

125125
raise NotImplementedError
126126

127+
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Report the KV cache layout this connector mandates, if any.

    Args:
        vllm_config (VllmConfig): the current vLLM configuration.

    Returns:
        The required layout name (e.g. "HND" or "NHD"), or None when
        the connector has no layout requirement.
    """
    # Base implementation: no specific layout is required.
    return None
140+
127141

128142
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING, Callable
66

77
import vllm.envs as envs
8+
from vllm.config import KVTransferConfig
89
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
910
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
1011
KVConnectorRole)
@@ -41,25 +42,15 @@ def create_connector_v0(cls, rank: int, local_rank: int,
4142
raise ValueError("Attempting to initialize a V0 Connector, "
4243
f"but found {envs.VLLM_USE_V1=}")
4344

44-
connector_name = config.kv_transfer_config.kv_connector
45-
if connector_name not in cls._registry:
46-
raise ValueError(f"Unsupported connector type: {connector_name}")
47-
48-
connector_cls = cls._registry[connector_name]()
45+
connector_cls = cls.get_connector_class(config.kv_transfer_config)
4946
assert issubclass(connector_cls, KVConnectorBase)
5047
return connector_cls(rank, local_rank, config)
5148

5249
@classmethod
53-
def create_connector_v1(
54-
cls,
55-
config: "VllmConfig",
56-
role: KVConnectorRole,
57-
) -> KVConnectorBase_V1:
58-
if not envs.VLLM_USE_V1:
59-
raise ValueError("Attempting to initialize a V1 Connector, "
60-
f"but found {envs.VLLM_USE_V1=}")
61-
62-
kv_transfer_config = config.kv_transfer_config
50+
def get_connector_class(
51+
cls, kv_transfer_config: "KVTransferConfig"
52+
) -> type[KVConnectorBaseType]:
53+
"""Get the connector class by name."""
6354
connector_name = kv_transfer_config.kv_connector
6455
if connector_name in cls._registry:
6556
connector_cls = cls._registry[connector_name]()
@@ -70,9 +61,23 @@ def create_connector_v1(
7061
f"Unsupported connector type: {connector_name}")
7162
connector_module = importlib.import_module(connector_module_path)
7263
connector_cls = getattr(connector_module, connector_name)
64+
return connector_cls
65+
66+
@classmethod
67+
def create_connector_v1(
68+
cls,
69+
config: "VllmConfig",
70+
role: KVConnectorRole,
71+
) -> KVConnectorBase_V1:
72+
if not envs.VLLM_USE_V1:
73+
raise ValueError("Attempting to initialize a V1 Connector, "
74+
f"but found {envs.VLLM_USE_V1=}")
75+
76+
kv_transfer_config = config.kv_transfer_config
77+
connector_cls = cls.get_connector_class(kv_transfer_config)
7378
assert issubclass(connector_cls, KVConnectorBase_V1)
7479
logger.info("Creating v1 connector with name: %s and engine_id: %s",
75-
connector_name, kv_transfer_config.engine_id)
80+
connector_cls.__name__, kv_transfer_config.engine_id)
7681
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
7782
# Scheduler connector:
7883
# - Co-locate with scheduler process

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import vllm.envs as envs
1414
from vllm import _custom_ops as ops
1515
from vllm.config import VllmConfig, get_current_vllm_config
16+
from vllm.distributed.kv_transfer.kv_connector.factory import (
17+
KVConnectorFactory)
1618
from vllm.logger import init_logger
1719
from vllm.v1.outputs import ModelRunnerOutput
1820

@@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
103105
# used for faster transfer.
104106
vllm_config = get_current_vllm_config()
105107
kv_config = vllm_config.kv_transfer_config
106-
if kv_config is not None and vllm_config.model_config is None:
107-
logger.warning_once("Unable to detect current VLLM config. " \
108-
"Defaulting to NHD kv cache layout.")
109-
elif kv_config is not None:
110-
use_mla = vllm_config.model_config.use_mla
111-
if not use_mla and kv_config.kv_connector == "NixlConnector":
112-
logger.info_once("NixlConnector detected. Setting KV cache " \
113-
"layout to HND for better xfer performance.")
114-
return "HND"
108+
if kv_config is not None:
109+
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
110+
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
111+
vllm_config)
112+
if required_kvcache_layout is not None:
113+
return required_kvcache_layout
114+
logger.info_once("Connectors do not specify a " \
115+
"kv cache layout, defaulting to NHD.")
115116
return "NHD"
116117

117118

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,17 @@ def request_finished(
299299
returned by the engine.
300300
"""
301301
return False, None
302+
303+
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Report the KV cache layout this connector mandates, if any.

    Args:
        vllm_config (VllmConfig): the current vLLM configuration.

    Returns:
        The required layout name (e.g. "HND" or "NHD"), or None when
        the connector has no layout requirement.
    """
    # Default for V1 connectors: accept any layout.
    return None

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,36 @@ def request_finished(
202202
self._requests_to_connector.pop(request.request_id, None)
203203

204204
return async_saves > 0, kv_txfer_params
205+
206+
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Get the required KV cache layout for this connector.

    Queries every child connector listed in
    ``kv_connector_extra_config["connectors"]``; all children that
    express a layout preference must agree.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout, e.g. HND or NHD.
        None if no child connector requires a specific layout.

    Raises:
        ValueError: if child connectors require conflicting layouts.
    """
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors")
    assert ktcs is not None
    layouts: set[str] = set()
    # Shallow copy so each child sees itself as the active transfer
    # config without mutating the caller's VllmConfig.
    temp_vllm_config = copy.copy(vllm_config)
    for ktc in ktcs:
        kv_transfer_config = KVTransferConfig(**ktc)
        temp_vllm_config.kv_transfer_config = kv_transfer_config
        required_kvcache_layout = KVConnectorFactory.get_connector_class(
            kv_transfer_config).get_required_kvcache_layout(
                temp_vllm_config)
        if required_kvcache_layout is not None:
            layouts.add(required_kvcache_layout)

    if len(layouts) > 1:
        # Message fixed: the original f-string concatenation was missing
        # a space between the two sentences (")...All connectors").
        raise ValueError(f"KV cache layout mismatch: "
                         f"found {len(layouts)} different layouts "
                         f"({', '.join(layouts)}). "
                         f"All connectors must use the same layout.")
    # Zero or one distinct layout: return it, or None when empty.
    return next(iter(layouts), None)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,25 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
133133
self.connector_worker = NixlConnectorWorker(
134134
vllm_config, self.engine_id)
135135

136+
############################################################
137+
# Class Methods
138+
############################################################
139+
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: VllmConfig) -> "Optional[str]":
    """
    Get the required KV cache layout for this connector.

    Added a return annotation (as a string, so no extra import is
    needed) to match the ``Optional[str]`` contract declared on the
    base class.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        "HND" for non-MLA models (better transfer performance);
        None when the model config is unavailable or the model uses
        MLA, falling back to the engine's default layout.
    """
    if vllm_config.model_config is None:
        logger.warning_once("Unable to detect current VLLM config. "
                            "Fallback to default kv cache layout.")
        return None
    use_mla = vllm_config.model_config.use_mla
    if use_mla:
        # With MLA the layout should not matter for transfers, so
        # defer to the default behavior.
        return None
    logger.info_once("NixlConnector setting KV cache "
                     "layout to HND for better xfer performance.")
    return "HND"
154+
136155
############################################################
137156
# Scheduler Side Methods
138157
############################################################
@@ -236,13 +255,13 @@ def get_num_new_matched_tokens(
236255
"""
237256
For remote prefill, pull all prompt blocks from remote
238257
asynchronously relative to engine execution.
239-
258+
240259
Args:
241260
request (Request): the request object.
242261
num_computed_tokens (int): the number of locally
243262
computed tokens for this request
244263
Returns:
245-
* the number of tokens that can be loaded from the
264+
* the number of tokens that can be loaded from the
246265
external KV cache beyond what is already computed.
247266
* true if the external KV cache tokens will be loaded
248267
asynchronously (between scheduler steps).

0 commit comments

Comments
 (0)