Skip to content

Commit 2341260

Browse files
DwyaneShi and Haiyang Shi authored
[Feature] Add RDMA auto-detection for kvcache (#1194)
- add RDMA auto-detection
- enable RDMA auto-detection in InfiniStore

Signed-off-by: Haiyang Shi <[email protected]>
Co-authored-by: Haiyang Shi <[email protected]>
1 parent b72a0a5 commit 2341260

File tree

6 files changed

+1025
-263
lines changed

6 files changed

+1025
-263
lines changed

python/aibrix_kvcache/aibrix_kvcache/envs.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
# Since 0.2.42, InfiniStore supports RDMA GID index in client config,
102102
# users can specify the GID index of each device in this format:
103103
# "mlx5_0:gid0,mlx5_1:gid1,mlx5_2:gid2"
104-
AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST: List[str] = ["mlx5_0"]
104+
AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST: List[str] = []
105105
AIBRIX_KV_CACHE_OL_INFINISTORE_USE_GDR: bool = True
106106

107107
# HPKV Env Vars
@@ -111,6 +111,11 @@
111111
AIBRIX_KV_CACHE_OL_HPKV_LOCAL_PORT: int = 12345
112112
AIBRIX_KV_CACHE_OL_HPKV_USE_GDR: bool = True
113113

114+
# RDMA Auto-Detection Env Vars
115+
# Defines the range of valid GIDs. Similar to NVSHMEM_IB_ADDR_RANGE
116+
# for NVSHMEM. It must be a valid CIDR; the default "::/0" spans the whole address space.
117+
AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE: str = "::/0"
118+
114119
# The begin-* and end* here are used by the documentation generator
115120
# to extract the used env vars.
116121

@@ -289,10 +294,13 @@
289294
"AIBRIX_KV_CACHE_OL_INFINISTORE_LINK_TYPE", "Ethernet"
290295
).strip()
291296
),
292-
"AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST": lambda: (
293-
os.getenv("AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST", "mlx5_0")
294-
.strip()
295-
.split(",")
297+
"AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST": lambda: list(
298+
filter(
299+
None,
300+
os.getenv("AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST", "")
301+
.strip()
302+
.split(","),
303+
)
296304
),
297305
"AIBRIX_KV_CACHE_OL_INFINISTORE_USE_GDR": lambda: (
298306
os.getenv("AIBRIX_KV_CACHE_OL_INFINISTORE_USE_GDR", "1").strip().lower()
@@ -315,6 +323,12 @@
315323
os.getenv("AIBRIX_KV_CACHE_OL_HPKV_USE_GDR", "1").strip().lower()
316324
in ("1", "true")
317325
),
326+
# ================== RDMA Auto-Detection Env Vars ==================
327+
"AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE": lambda: (
328+
os.getenv(
329+
"AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE", "::/0"
330+
).strip()
331+
),
318332
}
319333

320334
# end-env-vars-definition

python/aibrix_kvcache/aibrix_kvcache/l2/connectors/infinistore.py

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,18 @@
1818

1919
import infinistore
2020
import torch
21+
from validators import ipv4, ipv6
2122

2223
from ... import envs
2324
from ...common import AsyncBase
25+
from ...common.absl_logging import getLogger
2426
from ...memory import MemoryRegion
2527
from ...status import Status, StatusCodes
28+
from ...transport import AddrFamily, DeviceRequest, GIDType, RDMATransport
2629
from . import Connector, ConnectorFeature
2730

31+
logger = getLogger(__name__)
32+
2833

2934
@AsyncBase.async_wrap(delete="_delete")
3035
class InfiniStoreConnector(Connector[bytes, torch.Tensor], AsyncBase):
@@ -52,18 +57,73 @@ def from_envs(
5257
service_port = kwargs.get(
5358
"port", envs.AIBRIX_KV_CACHE_OL_INFINISTORE_SERVICE_PORT
5459
)
60+
assert ipv4(host_addr) or ipv6(host_addr), "Invalid host_addr"
5561
dev_list = envs.AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST
56-
assert (
57-
len(dev_list) > 0
58-
), "AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST is empty"
62+
connection_type = envs.AIBRIX_KV_CACHE_OL_INFINISTORE_CONNECTION_TYPE
63+
ib_port = envs.AIBRIX_KV_CACHE_OL_INFINISTORE_IB_PORT
64+
link_type = envs.AIBRIX_KV_CACHE_OL_INFINISTORE_LINK_TYPE
65+
66+
if connection_type == "TCP":
67+
config = infinistore.ClientConfig(
68+
host_addr=host_addr,
69+
service_port=service_port,
70+
connection_type=connection_type,
71+
link_type=link_type,
72+
)
73+
return cls(config, conn_id, executor)
74+
75+
# RDMA
76+
addr_family = (
77+
AddrFamily.AF_INET if ipv4(host_addr) else AddrFamily.AF_INET6
78+
)
79+
gid_type = GIDType.ROCE_V2 if link_type != "IB" else GIDType.IB_ROCE_V1
80+
if len(dev_list) == 0:
81+
logger.info(
82+
"AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST is not set, "
83+
"trying to auto-detect visible devices"
84+
)
85+
addr_range = envs.AIBRIX_KV_CACHE_OL_TRANSPORT_RDMA_ADDR_RANGE
86+
rdma = RDMATransport(
87+
addr_range=addr_range,
88+
addr_family=addr_family,
89+
gid_type=gid_type,
90+
)
91+
status = rdma.get_device_list()
92+
assert status.is_ok(), f"Failed to get device list: {status}"
93+
94+
devices = status.get()
95+
for d in devices:
96+
dev_list.append(f"{d.device_name}:{d.port_attrs.gid_index}")
97+
else:
98+
requests: List[DeviceRequest] = []
99+
for dev_name in dev_list:
100+
if ":" in dev_name:
101+
splits = dev_name.split(":")
102+
request = DeviceRequest(
103+
device_name=splits[0],
104+
gid_index=int(splits[1]),
105+
)
106+
else:
107+
request = DeviceRequest(device_name=dev_name)
108+
requests.append(request)
109+
rdma = RDMATransport(
110+
request=requests,
111+
addr_family=addr_family,
112+
gid_type=gid_type,
113+
)
114+
status = rdma.get_device_list()
115+
# Replace dev_list only when the lookup succeeded; on failure keep the user-specified entries as-is
116+
if status.is_ok():
117+
devices = status.get()
118+
dev_list.clear()
119+
for d in devices:
120+
dev_list.append(f"{d.device_name}:{d.port_attrs.gid_index}")
59121

60122
num_visible_gpus = torch.cuda.device_count()
61123

62-
dev_list = dev_list[:num_visible_gpus]
63-
assert num_visible_gpus % len(dev_list) == 0, (
64-
"AIBRIX_KV_CACHE_OL_INFINISTORE_VISIBLE_DEV_LIST is not a "
65-
"multiple of num. of visible GPUs"
66-
)
124+
dev_list = [
125+
dev_list[i % len(dev_list)] for i in range(num_visible_gpus)
126+
]
67127

68128
# For InfiniStore RDMA, we need to map the GPU index to the RNIC
69129
# index to support multi-GPU per RNIC. For example, if we have 8
@@ -75,6 +135,8 @@ def from_envs(
75135
dev_name = dev_list[rnic_idx]
76136
hint_gid_index_str = ""
77137

138+
logger.info(f"InfiniStore selects {dev_name}")
139+
78140
# If dev_name is in the format of "mlx5_i:xxx", then we need to
79141
# extract the dev_name and hint_gid_index from the dev_name.
80142
if ":" in dev_name:
@@ -85,9 +147,9 @@ def from_envs(
85147
config = infinistore.ClientConfig(
86148
host_addr=host_addr,
87149
service_port=service_port,
88-
connection_type=envs.AIBRIX_KV_CACHE_OL_INFINISTORE_CONNECTION_TYPE,
89-
ib_port=envs.AIBRIX_KV_CACHE_OL_INFINISTORE_IB_PORT,
90-
link_type=envs.AIBRIX_KV_CACHE_OL_INFINISTORE_LINK_TYPE,
150+
connection_type=connection_type,
151+
ib_port=ib_port,
152+
link_type=link_type,
91153
dev_name=dev_name,
92154
)
93155

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2024 The Aibrix Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .rdma import AddrFamily, DeviceRequest, GIDType, RDMATransport
16+
17+
__all__ = ["AddrFamily", "DeviceRequest", "GIDType", "RDMATransport"]

0 commit comments

Comments
 (0)