4
4
import importlib
5
5
from typing import TYPE_CHECKING , Callable
6
6
7
+ # yapf: disable
7
8
import vllm .envs as envs
8
- from vllm .distributed .kv_transfer .kv_connector .base import KVConnectorBase
9
+ from vllm .distributed .kv_transfer .kv_connector .base import (
10
+ KVConnectorBase , KVConnectorBaseType )
9
11
from vllm .distributed .kv_transfer .kv_connector .v1 import KVConnectorRole
10
12
from vllm .logger import init_logger
11
13
14
+ # yapf: enable
15
+
12
16
if TYPE_CHECKING :
13
- from vllm .config import VllmConfig
17
+ from vllm .config import KVTransferConfig , VllmConfig
14
18
15
19
logger = init_logger (__name__ )
16
20
@@ -42,17 +46,7 @@ def create_connector(
42
46
f"but found { envs .VLLM_USE_V1 = } " )
43
47
44
48
kv_transfer_config = config .kv_transfer_config
45
- connector_name = kv_transfer_config .kv_connector
46
- if connector_name in cls ._registry :
47
- connector_cls = cls ._registry [connector_name ]()
48
- else :
49
- connector_module_path = kv_transfer_config .kv_connector_module_path
50
- if connector_module_path is None :
51
- raise ValueError (
52
- f"Unsupported connector type: { connector_name } " )
53
- connector_module = importlib .import_module (connector_module_path )
54
- connector_cls = getattr (connector_module , connector_name )
55
- assert issubclass (connector_cls , KVConnectorBase )
49
+ connector_cls = cls .get_connector_class (kv_transfer_config )
56
50
logger .info ("Creating v1 connector with name: %s and engine_id: %s" ,
57
51
connector_cls .__name__ , kv_transfer_config .engine_id )
58
52
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@@ -65,6 +59,23 @@ def create_connector(
65
59
# We build separately to enforce strict separation
66
60
return connector_cls (config , role )
67
61
62
+ @classmethod
63
+ def get_connector_class (
64
+ cls , kv_transfer_config : "KVTransferConfig"
65
+ ) -> type [KVConnectorBaseType ]:
66
+ """Get the connector class by name."""
67
+ connector_name = kv_transfer_config .kv_connector
68
+ if connector_name in cls ._registry :
69
+ connector_cls = cls ._registry [connector_name ]()
70
+ else :
71
+ connector_module_path = kv_transfer_config .kv_connector_module_path
72
+ if connector_module_path is None :
73
+ raise ValueError (
74
+ f"Unsupported connector type: { connector_name } " )
75
+ connector_module = importlib .import_module (connector_module_path )
76
+ connector_cls = getattr (connector_module , connector_name )
77
+ return connector_cls
78
+
68
79
69
80
# Register various connectors here.
70
81
# The registration should not be done in each individual file, as we want to
0 commit comments