18
18
19
19
import infinistore
20
20
import torch
21
+ from validators import ipv4 , ipv6
21
22
22
23
from ... import envs
23
24
from ...common import AsyncBase
25
+ from ...common .absl_logging import getLogger
24
26
from ...memory import MemoryRegion
25
27
from ...status import Status , StatusCodes
28
+ from ...transport import AddrFamily , DeviceRequest , GIDType , RDMATransport
26
29
from . import Connector , ConnectorFeature
27
30
31
+ logger = getLogger (__name__ )
32
+
28
33
29
34
@AsyncBase .async_wrap (delete = "_delete" )
30
35
class InfiniStoreConnector (Connector [bytes , torch .Tensor ], AsyncBase ):
@@ -52,18 +57,73 @@ def from_envs(
52
57
service_port = kwargs .get (
53
58
"port" , envs .AIBRIX_KV_CACHE_OL_INFINISTORE_SERVICE_PORT
54
59
)
60
+ assert ipv4 (host_addr ) or ipv6 (host_addr ), "Invalid host_addr"
55
61
dev_list = envs .AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST
56
- assert (
57
- len (dev_list ) > 0
58
- ), "AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST is empty"
62
+ connection_type = envs .AIBRIX_KV_CACHE_OL_INFINISTORE_CONNECTION_TYPE
63
+ ib_port = envs .AIBRIX_KV_CACHE_OL_INFINISTORE_IB_PORT
64
+ link_type = envs .AIBRIX_KV_CACHE_OL_INFINISTORE_LINK_TYPE
65
+
66
+ if connection_type == "TCP" :
67
+ config = infinistore .ClientConfig (
68
+ host_addr = host_addr ,
69
+ service_port = service_port ,
70
+ connection_type = connection_type ,
71
+ link_type = link_type ,
72
+ )
73
+ return cls (config , conn_id , executor )
74
+
75
+ # RDMA
76
+ addr_family = (
77
+ AddrFamily .AF_INET if ipv4 (host_addr ) else AddrFamily .AF_INET6
78
+ )
79
+ gid_type = GIDType .ROCE_V2 if link_type != "IB" else GIDType .IB_ROCE_V1
80
+ if len (dev_list ) == 0 :
81
+ logger .info (
82
+ "AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST is not set, "
83
+ "trying to auto-detect visible devices"
84
+ )
85
+ addr_range = envs .AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE
86
+ rdma = RDMATransport (
87
+ addr_range = addr_range ,
88
+ addr_family = addr_family ,
89
+ gid_type = gid_type ,
90
+ )
91
+ status = rdma .get_device_list ()
92
+ assert status .is_ok (), f"Failed to get device list: { status } "
93
+
94
+ devices = status .get ()
95
+ for d in devices :
96
+ dev_list .append (f"{ d .device_name } :{ d .port_attrs .gid_index } " )
97
+ else :
98
+ requests : List [DeviceRequest ] = []
99
+ for dev_name in dev_list :
100
+ if ":" in dev_name :
101
+ splits = dev_name .split (":" )
102
+ request = DeviceRequest (
103
+ device_name = splits [0 ],
104
+ gid_index = int (splits [1 ]),
105
+ )
106
+ else :
107
+ request = DeviceRequest (device_name = dev_name )
108
+ requests .append (request )
109
+ rdma = RDMATransport (
110
+ request = requests ,
111
+ addr_family = addr_family ,
112
+ gid_type = gid_type ,
113
+ )
114
+ status = rdma .get_device_list ()
115
+ # Only update dev_list if we got a new list
116
+ if status .is_ok ():
117
+ devices = status .get ()
118
+ dev_list .clear ()
119
+ for d in devices :
120
+ dev_list .append (f"{ d .device_name } :{ d .port_attrs .gid_index } " )
59
121
60
122
num_visible_gpus = torch .cuda .device_count ()
61
123
62
- dev_list = dev_list [:num_visible_gpus ]
63
- assert num_visible_gpus % len (dev_list ) == 0 , (
64
- "AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST is not a "
65
- "multiple of num. of visible GPUs"
66
- )
124
+ dev_list = [
125
+ dev_list [i % len (dev_list )] for i in range (num_visible_gpus )
126
+ ]
67
127
68
128
# For InfiniStore RDMA, we need to map the GPU index to the RNIC
69
129
# index to support multi-GPU per RNIC. For example, if we have 8
@@ -75,6 +135,8 @@ def from_envs(
75
135
dev_name = dev_list [rnic_idx ]
76
136
hint_gid_index_str = ""
77
137
138
+ logger .info (f"InfiniStore selects { dev_name } " )
139
+
78
140
# If dev_name is in the format of "mlx5_i:xxx", then we need to
79
141
# extract the dev_name and hint_gid_index from the dev_name.
80
142
if ":" in dev_name :
@@ -85,9 +147,9 @@ def from_envs(
85
147
config = infinistore .ClientConfig (
86
148
host_addr = host_addr ,
87
149
service_port = service_port ,
88
- connection_type = envs . AIBRIX_KV_CACHE_OL_INFINISTORE_CONNECTION_TYPE ,
89
- ib_port = envs . AIBRIX_KV_CACHE_OL_INFINISTORE_IB_PORT ,
90
- link_type = envs . AIBRIX_KV_CACHE_OL_INFINISTORE_LINK_TYPE ,
150
+ connection_type = connection_type ,
151
+ ib_port = ib_port ,
152
+ link_type = link_type ,
91
153
dev_name = dev_name ,
92
154
)
93
155
0 commit comments