Skip to content

Commit 5be40a0

Browse files
committed
feat: add support for custom compile options in torch_xla.compile and PJRT backend
This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code. Key changes: * Python API * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values). * Added torch_xla.set_custom_compile_options() utility for setting compile options globally. * Added internal binding _XLAC._set_custom_compile_options(). * C++ Runtime * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient. * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation. * Options are stringified before being passed to XLA for compatibility. Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
1 parent 7d989d1 commit 5be40a0

File tree

6 files changed

+55
-1
lines changed

6 files changed

+55
-1
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3187,6 +3187,16 @@ void InitXlaModuleBindings(py::module m) {
31873187
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
31883188
"without a data handle or an IR.";
31893189
})
3190+
.def("_set_custom_compile_options",
     [](const py::dict& compile_options) {
       // Convert the Python dict into a string->string map for the
       // runtime. Keys must be Python strings; values of any type are
       // stringified via their Python str() form.
       // NOTE(review): py::str renders Python bools as "True"/"False" —
       // assumes XLA's option parsing accepts that casing; confirm.
       std::unordered_map<std::string, std::string> stringified;
       for (const auto& kv : compile_options) {
         const std::string option_name = kv.first.cast<std::string>();
         stringified[option_name] = py::str(kv.second).cast<std::string>();
       }
       // Hand the options to the active computation client; they apply
       // to subsequent compilations.
       runtime::GetComputationClientOrDie()->SetCustomCompileOptions(
           stringified);
     })
31903200
.def(
31913201
// from an XLA tensor to a PyCapsule.
31923202
// When consuming the PyCapsule, we should synchronize

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ class ComputationClient {
444444
// after the last ':' character of the device string.
445445
static int64_t GetDeviceOrdinal(const std::string& device);
446446

447+
virtual void SetCustomCompileOptions(
448+
const std::unordered_map<std::string, std::string>& options) = 0;
449+
447450
protected:
448451
static constexpr auto spmd_device_str = "SPMD:0";
449452

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ class IfrtComputationClient : public ComputationClient {
174174
XLA_ERROR() << __FUNCTION__ << " not implemented";
175175
}
176176

177+
void SetCustomCompileOptions(
178+
const std::unordered_map<std::string, std::string>& options) override {
179+
XLA_ERROR() << __FUNCTION__ << " not implemented";
180+
}
181+
177182
// Creates a new instance of IfrtComputationClient and initializes it.
178183
static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
179184
Create();

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,13 +558,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
558558

559559
for (auto& instance : instances) {
560560
xla::CompileOptions compile_options;
561+
for (auto& option : custom_compile_options_) {
562+
compile_options.env_option_overrides.push_back(
563+
{option.first, option.second});
564+
}
561565
if (enable_cm_in_mp) {
562566
compile_options.executable_build_options.set_use_spmd_partitioning(true);
563567
compile_options.env_option_overrides.push_back(
564568
{"xla_tpu_decompose_all_gather_einsum", true});
565569
compile_options.env_option_overrides.push_back(
566570
{"xla_tpu_decompose_einsum_reduce_scatter", true});
567571
}
572+
568573
if (instance.is_sharded) {
569574
// TODO(yeounoh) multi-host, multi-slice configurations
570575
compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1093,5 +1098,14 @@ void PjRtComputationClient::OnReadyCallback(
10931098
[callback](absl::Status unused) { callback(); });
10941099
}
10951100

1101+
void PjRtComputationClient::SetCustomCompileOptions(
1102+
const std::unordered_map<std::string, std::string>& options) {
1103+
// Stringfy values
1104+
custom_compile_options_.clear();
1105+
for (const auto& [key, value] : options) {
1106+
custom_compile_options_[key] = value;
1107+
}
1108+
}
1109+
10961110
} // namespace runtime
10971111
} // namespace torch_xla

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ class PjRtComputationClient : public ComputationClient {
172172
void OnReadyCallback(DataPtr data,
173173
const std::function<void()>& callback) override;
174174

175+
void SetCustomCompileOptions(
176+
const std::unordered_map<std::string, std::string>& options) override;
177+
175178
// Creates a new instance of PjRtComputationClient and initializes it.
176179
static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
177180
Create();
@@ -204,6 +207,7 @@ class PjRtComputationClient : public ComputationClient {
204207
// If not nullptr, invoke this instead of the actual XLA compilation. Used
205208
// only for testing.
206209
std::function<absl::Status()> fake_xla_compile_ = nullptr;
210+
std::unordered_map<std::string, std::string> custom_compile_options_;
207211

208212
xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
209213

torch_xla/torch_xla.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def compile(
116116
full_graph: Optional[bool] = False,
117117
name: Optional[str] = None,
118118
max_different_graphs: Optional[int] = None,
119+
custom_compile_options: Optional[dict] = None,
119120
):
120121
"""
121122
Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,8 @@ def compile(
136137
max_different_graphs (Optional[int]): number of different traced graphs of the given
137138
model/function that we are allowed to have. An error will be raised in case this limit
138139
is exceeded.
140+
custom_compile_options (Optional[dict]): A dictionary of custom compile options to be set.
141+
The keys are strings and the values can be of type bool, float, int, or str.
139142
140143
Example::
141144
@@ -214,7 +217,8 @@ def _compile():
214217
sync()
215218
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
216219
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
217-
220+
if custom_compile_options is not None and len(custom_compile_options) > 0:
221+
torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
218222
return _compile() if f is None else _compile()(f)
219223

220224

@@ -264,3 +268,17 @@ def launch(
264268
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
265269
else:
266270
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
271+
272+
def set_custom_compile_options(
    options: Optional[dict] = None,
):
  """Sets custom compile options for the XLA compilation.

  Args:
    options: A dictionary of custom compile options to be set.
      The keys are strings and the values can be of type bool, float, int, or str.
  """
  # An absent mapping is treated the same as an empty one.
  torch_xla._XLAC._set_custom_compile_options({} if options is None else options)
284+

0 commit comments

Comments
 (0)