1+ From 8b54dc4825628b088ea47e1dda7189477a0989fb Mon Sep 17 00:00:00 2001
2+ From: Hugo Mano <hugo@zml.ai>
3+ Date: Tue, 28 Jan 2025 16:15:15 +0100
4+ Subject: [PATCH] [PJRT] Expose should_stage_host_to_device_transfers as
5+ create option
6+
7+ Expose GPU option `should_stage_host_to_device_transfers `
8+ as configuration in PJRT client create func.
9+ ---
10+ xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 7 +++++++
11+ 1 file changed, 7 insertions(+)
12+
13+ diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
14+ index cf23c997fb..b3157b59b1 100644
15+ --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
16+ +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
17+ @@ -84,6 +84,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
18+ {"visible_devices", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List},
19+ {"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64},
20+ {"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64},
21+ + {"should_stage_host_to_device_transfers", PJRT_NamedValue_Type::PJRT_NamedValue_kBool},
22+ {"enable_mock_nccl", PJRT_NamedValue_Type::PJRT_NamedValue_kBool},
23+ {"mock_gpu_topology", PJRT_NamedValue_Type::PJRT_NamedValue_kString},
24+ });
25+ @@ -139,6 +140,11 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
26+ if (auto it = create_options.find("num_nodes"); it != create_options.end()) {
27+ num_nodes = std::get<int64_t>(it->second);
28+ }
29+ + bool should_stage_host_to_device_transfers = true;
30+ + if (auto it = create_options.find("should_stage_host_to_device_transfers");
31+ + it != create_options.end()) {
32+ + should_stage_host_to_device_transfers = std::get<bool>(it->second);
33+ + }
34+ bool enable_mock_nccl = false;
35+ if (auto it = create_options.find("enable_mock_nccl");
36+ it != create_options.end()) {
37+ @@ -159,6 +165,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
38+ options.kv_store = pjrt::ToCppKeyValueStore(
39+ args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback,
40+ args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg);
41+ + options.should_stage_host_to_device_transfers = should_stage_host_to_device_transfers;
42+ options.enable_mock_nccl = enable_mock_nccl;
43+ options.mock_gpu_topology = mock_gpu_topology;
44+ PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
45+ - -
46+ 2.39.5 (Apple Git-154)
0 commit comments