Skip to content

Commit 35dd8ef

Browse files
authored
openxla: patch PJRT client should_stage_host_to_device_transfers (#14)
1 parent e44fd02 commit 35dd8ef

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
From 8b54dc4825628b088ea47e1dda7189477a0989fb Mon Sep 17 00:00:00 2001
2+
From: Hugo Mano <hugo@zml.ai>
3+
Date: Tue, 28 Jan 2025 16:15:15 +0100
4+
Subject: [PATCH] [PJRT] Expose should_stage_host_to_device_transfers as
5+
create option
6+
7+
Expose GPU option `should_stage_host_to_device_transfers `
8+
as configuration in PJRT client create func.
9+
---
10+
xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 7 +++++++
11+
1 file changed, 7 insertions(+)
12+
13+
diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
14+
index cf23c997fb..b3157b59b1 100644
15+
--- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
16+
+++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc
17+
@@ -84,6 +84,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
18+
{"visible_devices", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List},
19+
{"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64},
20+
{"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64},
21+
+ {"should_stage_host_to_device_transfers", PJRT_NamedValue_Type::PJRT_NamedValue_kBool},
22+
{"enable_mock_nccl", PJRT_NamedValue_Type::PJRT_NamedValue_kBool},
23+
{"mock_gpu_topology", PJRT_NamedValue_Type::PJRT_NamedValue_kString},
24+
});
25+
@@ -139,6 +140,11 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
26+
if (auto it = create_options.find("num_nodes"); it != create_options.end()) {
27+
num_nodes = std::get<int64_t>(it->second);
28+
}
29+
+ bool should_stage_host_to_device_transfers = true;
30+
+ if (auto it = create_options.find("should_stage_host_to_device_transfers");
31+
+ it != create_options.end()) {
32+
+ should_stage_host_to_device_transfers = std::get<bool>(it->second);
33+
+ }
34+
bool enable_mock_nccl = false;
35+
if (auto it = create_options.find("enable_mock_nccl");
36+
it != create_options.end()) {
37+
@@ -159,6 +165,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
38+
options.kv_store = pjrt::ToCppKeyValueStore(
39+
args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback,
40+
args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg);
41+
+ options.should_stage_host_to_device_transfers = should_stage_host_to_device_transfers;
42+
options.enable_mock_nccl = enable_mock_nccl;
43+
options.mock_gpu_topology = mock_gpu_topology;
44+
PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
45+
--
46+
2.39.5 (Apple Git-154)

third_party/openxla/xla.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ filegroup(
3030
"patches/20240901-003-Only-export-GetPjrtApi-symbol-on-macOS.patch", # PR: https://github.com/openxla/xla/pull/16696
3131
"patches/20250120-001-Enable-nvptxcompiler-with-nvjitlink.patch", # Allow us to levarage technologies flagged for Google only
3232
"patches/20250122-001-Fix-LoadedNvJitLinkHasKnownIssues-check.patch", # PR: https://github.com/openxla/xla/pull/2172
33+
"patches/20250128-001-PJRT-Expose-should_stage_host_to_device_transfers.patch", # PR: https://github.com/openxla/xla/pull/21965
3334
],
3435
)
3536

0 commit comments

Comments
 (0)