Skip to content

Commit 686cb76

Browse files
committed
feat: add support for custom compile options in torch_xla.compile and PJRT backend
This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.

Key changes:

* Python API
  * Added a `custom_compile_options` parameter to `torch_xla.compile` for passing compile-time options as a dict (supports bool, float, int, and str values).
  * Added a `torch_xla.set_custom_compile_options()` utility for setting compile options globally.
  * Added the internal binding `_XLAC._set_custom_compile_options()`.
* C++ Runtime
  * Added a `SetCustomCompileOptions()` virtual method to `ComputationClient` and implemented it in `PjRtComputationClient`.
  * `PjRtComputationClient` now stores `custom_compile_options_` and injects them into `xla::CompileOptions.env_option_overrides` during compilation.
  * Options are stringified before being passed to XLA for compatibility.

Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
1 parent 24bb34c commit 686cb76

File tree

6 files changed

+55
-1
lines changed

6 files changed

+55
-1
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3310,6 +3310,16 @@ void InitXlaModuleBindings(py::module m) {
33103310
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
33113311
"without a data handle or an IR.";
33123312
})
3313+
.def("_set_custom_compile_options",
3314+
[](const py::dict& compile_options) {
3315+
std::unordered_map<std::string, std::string> options;
3316+
for (const auto& item : compile_options) {
3317+
std::string key = item.first.cast<std::string>();
3318+
options[key] = py::str(item.second).cast<std::string>();
3319+
}
3320+
runtime::GetComputationClientOrDie()->SetCustomCompileOptions(
3321+
options);
3322+
})
33133323
.def(
33143324
// from an XLA tensor to a PyCapsule.
33153325
// When consuming the PyCapsule, we should synchronize

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ class ComputationClient {
446446
// after the last ':' character of the device string.
447447
static int64_t GetDeviceOrdinal(const std::string& device);
448448

449+
virtual void SetCustomCompileOptions(
450+
const std::unordered_map<std::string, std::string>& options) = 0;
451+
449452
protected:
450453
static constexpr auto spmd_device_str = "SPMD:0";
451454

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ class IfrtComputationClient : public ComputationClient {
172172
XLA_ERROR() << __FUNCTION__ << " not implemented";
173173
}
174174

175+
void SetCustomCompileOptions(
176+
const std::unordered_map<std::string, std::string>& options) override {
177+
XLA_ERROR() << __FUNCTION__ << " not implemented";
178+
}
179+
175180
// Creates a new instance of IfrtComputationClient and initializes it.
176181
static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
177182
Create();

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,13 +555,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
555555

556556
for (auto& instance : instances) {
557557
xla::CompileOptions compile_options;
558+
for (auto& option : custom_compile_options_) {
559+
compile_options.env_option_overrides.push_back(
560+
{option.first, option.second});
561+
}
558562
if (enable_cm_in_mp) {
559563
compile_options.executable_build_options.set_use_spmd_partitioning(true);
560564
compile_options.env_option_overrides.push_back(
561565
{"xla_tpu_decompose_all_gather_einsum", true});
562566
compile_options.env_option_overrides.push_back(
563567
{"xla_tpu_decompose_einsum_reduce_scatter", true});
564568
}
569+
565570
if (instance.is_sharded) {
566571
// TODO(yeounoh) multi-host, multi-slice configurations
567572
compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1056,5 +1061,14 @@ void PjRtComputationClient::OnReadyCallback(
10561061
[callback](absl::Status unused) { callback(); });
10571062
}
10581063

1064+
void PjRtComputationClient::SetCustomCompileOptions(
1065+
const std::unordered_map<std::string, std::string>& options) {
1066+
// Stringfy values
1067+
custom_compile_options_.clear();
1068+
for (const auto& [key, value] : options) {
1069+
custom_compile_options_[key] = value;
1070+
}
1071+
}
1072+
10591073
} // namespace runtime
10601074
} // namespace torch_xla

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ class PjRtComputationClient : public ComputationClient {
165165
void OnReadyCallback(DataPtr data,
166166
const std::function<void()>& callback) override;
167167

168+
void SetCustomCompileOptions(
169+
const std::unordered_map<std::string, std::string>& options) override;
170+
168171
// Creates a new instance of PjRtComputationClient and initializes it.
169172
static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
170173
Create();
@@ -197,6 +200,7 @@ class PjRtComputationClient : public ComputationClient {
197200
// If not nullptr, invoke this instead of the actual XLA compilation. Used
198201
// only for testing.
199202
std::function<absl::Status()> fake_xla_compile_ = nullptr;
203+
std::unordered_map<std::string, std::string> custom_compile_options_;
200204

201205
xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
202206

torch_xla/torch_xla.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def compile(
116116
full_graph: Optional[bool] = False,
117117
name: Optional[str] = None,
118118
max_different_graphs: Optional[int] = None,
119+
custom_compile_options: Optional[dict] = None,
119120
):
120121
"""
121122
Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,8 @@ def compile(
136137
max_different_graphs (Optional[int]): number of different traced graphs of the given
137138
model/function that we are allowed to have. An error will be raised in case this limit
138139
is exceeded.
140+
custom_compile_options (Optional[dict]): A dictionary of custom compile options to be set.
141+
The keys are strings and the values can be of type bool, float, int, or str.
139142
140143
Example::
141144
@@ -214,7 +217,8 @@ def _compile():
214217
sync()
215218
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
216219
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
217-
220+
if custom_compile_options is not None and len(custom_compile_options) > 0:
221+
torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
218222
return _compile() if f is None else _compile()(f)
219223

220224

@@ -264,3 +268,17 @@ def launch(
264268
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
265269
else:
266270
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
271+
272+
def set_custom_compile_options(
    options: Optional[dict] = None,
):
  """Registers custom compile options globally for XLA compilation.

  Args:
    options: A dictionary of custom compile options to be set.
      The keys are strings and the values can be of type bool, float, int, or str.
      Passing ``None`` is equivalent to passing an empty dict.
  """
  # Normalize None to an empty dict before handing off to the C++ binding.
  torch_xla._XLAC._set_custom_compile_options(options or {})
284+

0 commit comments

Comments
 (0)