Skip to content

Commit d030b01

Browse files
NickLucche and njhill authored
[BugFix][Nixl][PD] Fix heterogenous TP (#22663)
Signed-off-by: NickLucche <[email protected]> Co-authored-by: Nick Hill <[email protected]>
1 parent 767e63b commit d030b01

File tree

2 files changed: 31 additions and 17 deletions.

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
import importlib
55
from typing import TYPE_CHECKING, Callable
66

7+
# yapf: disable
78
import vllm.envs as envs
8-
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
9+
from vllm.distributed.kv_transfer.kv_connector.base import (
10+
KVConnectorBase, KVConnectorBaseType)
911
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
1012
from vllm.logger import init_logger
1113

14+
# yapf: enable
15+
1216
if TYPE_CHECKING:
13-
from vllm.config import VllmConfig
17+
from vllm.config import KVTransferConfig, VllmConfig
1418

1519
logger = init_logger(__name__)
1620

@@ -42,17 +46,7 @@ def create_connector(
4246
f"but found {envs.VLLM_USE_V1=}")
4347

4448
kv_transfer_config = config.kv_transfer_config
45-
connector_name = kv_transfer_config.kv_connector
46-
if connector_name in cls._registry:
47-
connector_cls = cls._registry[connector_name]()
48-
else:
49-
connector_module_path = kv_transfer_config.kv_connector_module_path
50-
if connector_module_path is None:
51-
raise ValueError(
52-
f"Unsupported connector type: {connector_name}")
53-
connector_module = importlib.import_module(connector_module_path)
54-
connector_cls = getattr(connector_module, connector_name)
55-
assert issubclass(connector_cls, KVConnectorBase)
49+
connector_cls = cls.get_connector_class(kv_transfer_config)
5650
logger.info("Creating v1 connector with name: %s and engine_id: %s",
5751
connector_cls.__name__, kv_transfer_config.engine_id)
5852
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@@ -65,6 +59,23 @@ def create_connector(
6559
# We build separately to enforce strict separation
6660
return connector_cls(config, role)
6761

62+
@classmethod
63+
def get_connector_class(
64+
cls, kv_transfer_config: "KVTransferConfig"
65+
) -> type[KVConnectorBaseType]:
66+
"""Get the connector class by name."""
67+
connector_name = kv_transfer_config.kv_connector
68+
if connector_name in cls._registry:
69+
connector_cls = cls._registry[connector_name]()
70+
else:
71+
connector_module_path = kv_transfer_config.kv_connector_module_path
72+
if connector_module_path is None:
73+
raise ValueError(
74+
f"Unsupported connector type: {connector_name}")
75+
connector_module = importlib.import_module(connector_module_path)
76+
connector_cls = getattr(connector_module, connector_name)
77+
return connector_cls
78+
6879

6980
# Register various connectors here.
7081
# The registration should not be done in each individual file, as we want to

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import vllm.envs as envs
1414
from vllm import _custom_ops as ops
1515
from vllm.config import VllmConfig, get_current_vllm_config
16-
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
17-
KVConnectorBase_V1)
16+
from vllm.distributed.kv_transfer.kv_connector.factory import (
17+
KVConnectorFactory)
1818
from vllm.logger import init_logger
1919
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
2020

@@ -106,8 +106,9 @@ def get_kv_connector_cache_layout():
106106
vllm_config = get_current_vllm_config()
107107
kv_config = vllm_config.kv_transfer_config
108108
if kv_config is not None:
109-
required_kvcache_layout = (
110-
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
109+
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
110+
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
111+
vllm_config)
111112
if required_kvcache_layout is not None:
112113
return required_kvcache_layout
113114
logger.info_once("Connectors do not specify a " \
@@ -143,6 +144,8 @@ def update_finished_set(req_ids: Optional[set[str]],
143144
finished_recving = set[str]()
144145
for output in outputs:
145146
output = output.kv_connector_output
147+
if not output:
148+
continue
146149
update_finished_set(output.finished_sending,
147150
self._send_remaining_count, finished_sending)
148151
update_finished_set(output.finished_recving,

0 commit comments

Comments
 (0)