Skip to content

Commit 00c211a

Browse files
committed
rpc : add support for multiple devices
Allow rpc-server to expose multiple devices from a single endpoint. Change the RPC protocol to include a device identifier where needed. Add a new API to get the device count from an RPC endpoint. Closes: ggml-org#15210
1 parent 54dbc37 commit 00c211a

File tree

4 files changed

+318
-147
lines changed

4 files changed

+318
-147
lines changed

common/arg.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,17 +1721,26 @@ static void add_rpc_devices(const std::string & servers) {
17211721
if (!rpc_reg) {
17221722
throw std::invalid_argument("failed to find RPC backend");
17231723
}
1724-
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
1724+
typedef int (*ggml_backend_rpc_get_device_count_t)(const char * endpoint);
1725+
ggml_backend_rpc_get_device_count_t ggml_backend_rpc_get_device_count_fn =
1726+
(ggml_backend_rpc_get_device_count_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_get_device_count");
1727+
if (!ggml_backend_rpc_get_device_count_fn) {
1728+
throw std::invalid_argument("failed to find RPC device count function");
1729+
}
1730+
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint, uint32_t device);
17251731
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
17261732
if (!ggml_backend_rpc_add_device_fn) {
17271733
throw std::invalid_argument("failed to find RPC device add function");
17281734
}
17291735
for (const auto & server : rpc_servers) {
1730-
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
1731-
if (dev) {
1732-
ggml_backend_device_register(dev);
1733-
} else {
1734-
throw std::invalid_argument("failed to register RPC device");
1736+
int dev_count = ggml_backend_rpc_get_device_count_fn(server.c_str());
1737+
for (int i = 0; i < dev_count; i++) {
1738+
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str(), i);
1739+
if (dev) {
1740+
ggml_backend_device_register(dev);
1741+
} else {
1742+
throw std::invalid_argument("failed to register RPC device");
1743+
}
17351744
}
17361745
}
17371746
}

ggml/include/ggml-rpc.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,26 @@
77
extern "C" {
88
#endif
99

10-
#define RPC_PROTO_MAJOR_VERSION 2
10+
#define RPC_PROTO_MAJOR_VERSION 3
1111
#define RPC_PROTO_MINOR_VERSION 0
1212
#define RPC_PROTO_PATCH_VERSION 0
1313
#define GGML_RPC_MAX_SERVERS 16
1414

1515
// backend API
16-
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
16+
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
1717
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
1818

19-
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
19+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
2020

21-
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
21+
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
2222

23-
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
24-
const char * cache_dir,
25-
size_t free_mem, size_t total_mem);
23+
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
24+
size_t dev_count, ggml_backend_t * backends,
25+
size_t * free_mem, size_t * total_mem);
2626

2727
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
28-
29-
GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
28+
GGML_BACKEND_API int ggml_backend_rpc_get_device_count(const char * endpoint);
29+
GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint, uint32_t device);
3030

3131
#ifdef __cplusplus
3232
}

0 commit comments

Comments
 (0)