13 changes: 13 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -3306,6 +3306,19 @@ void InitXlaModuleBindings(py::module m) {
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
"without a data handle or an IR.";
})
.def("_set_custom_compile_options",
[](const py::dict& compile_options) {
std::unordered_map<std::string, std::string> options;
for (const auto& item : compile_options) {
// Both keys and values are stringified via py::str.
const std::string key = py::str(item.first);
options[key] = py::str(item.second);
}
XLA_ASSIGN_OR_THROW(
runtime::ComputationClient * absl_nonnull client,
runtime::GetComputationClient());
client->SetCustomCompileOptions(options);
})
.def(
// from an XLA tensor to a PyCapsule.
// When consuming the PyCapsule, we should synchronize
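On the Python side, this binding stringifies both keys and values with py::str, so a bool value reaches the backend as the string "True" or "False". A minimal usage sketch (the flag names here are illustrative, not a recommendation):

import torch_xla

# Hedged sketch: each value below is converted with py::str in the binding,
# so True becomes the string "True" and 2 becomes the string "2".
torch_xla._XLAC._set_custom_compile_options({
    "xla_gpu_enable_latency_hiding_scheduler": True,  # stored as "True"
    "xla_gpu_autotune_level": 2,                      # stored as "2"
})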
8 changes: 8 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
@@ -446,6 +446,14 @@ class ComputationClient {
// after the last ':' character of the device string.
static int64_t GetDeviceOrdinal(const std::string& device);

// Sets XLA compile option overrides used by the backend compiler.
// - The map keys are XLA compiler flag names (env option override keys).
// - The values are stringified flag values.
// - Calling this method overwrites any previously set options.
// (Pass an empty map to clear.)
virtual void SetCustomCompileOptions(
const std::unordered_map<std::string, std::string>& options) = 0;

protected:
static constexpr auto spmd_device_str = "SPMD:0";

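Because each call replaces the stored map wholesale, a caller who wants to add one more flag must re-send the full set. A short sketch of the documented overwrite/clear semantics (flag names are placeholders):

import torch_xla

torch_xla._XLAC._set_custom_compile_options({"some_flag": "1"})
# The next call overwrites the map entirely; "some_flag" is dropped.
torch_xla._XLAC._set_custom_compile_options({"another_flag": "2"})
# Passing an empty dict clears all previously set options.
torch_xla._XLAC._set_custom_compile_options({})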
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -172,6 +172,11 @@ class IfrtComputationClient : public ComputationClient {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

void SetCustomCompileOptions(
const std::unordered_map<std::string, std::string>& options) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

// Creates a new instance of IfrtComputationClient and initializes it.
static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
Create();
12 changes: 12 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -554,13 +554,17 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(

for (auto& instance : instances) {
xla::CompileOptions compile_options;
for (const auto& [name, value] : custom_compile_options_) {
compile_options.env_option_overrides.push_back({name, value});
}
if (enable_cm_in_mp) {
compile_options.executable_build_options.set_use_spmd_partitioning(true);
compile_options.env_option_overrides.push_back(
{"xla_tpu_decompose_all_gather_einsum", true});
compile_options.env_option_overrides.push_back(
{"xla_tpu_decompose_einsum_reduce_scatter", true});
}

if (instance.is_sharded) {
// TODO(yeounoh) multi-host, multi-slice configurations
compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1052,5 +1056,13 @@ void PjRtComputationClient::OnReadyCallback(
[callback](absl::Status unused) { callback(); });
}

void PjRtComputationClient::SetCustomCompileOptions(
    const std::unordered_map<std::string, std::string>& options) {
  // Replace, rather than merge with, any previously set options.
  custom_compile_options_ = options;
}

} // namespace runtime
} // namespace torch_xla
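For reference, a hedged Python model of what Compile now does with these options: the custom entries are appended to env_option_overrides first, then the two hard-coded cm_in_mp flags follow, mirroring the C++ above.

# Illustrative model only; the names mirror the C++ members, not a public API.
def build_env_option_overrides(custom_compile_options, enable_cm_in_mp):
  overrides = [(name, value) for name, value in custom_compile_options.items()]
  if enable_cm_in_mp:
    overrides.append(("xla_tpu_decompose_all_gather_einsum", True))
    overrides.append(("xla_tpu_decompose_einsum_reduce_scatter", True))
  return overrides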
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -165,6 +165,10 @@ class PjRtComputationClient : public ComputationClient {
void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) override;

// See base class for semantics. This call overwrites previously set options.
void SetCustomCompileOptions(
const std::unordered_map<std::string, std::string>& options) override;

// Creates a new instance of PjRtComputationClient and initializes it.
static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
Create();
@@ -197,6 +201,7 @@ class PjRtComputationClient : public ComputationClient {
// If not nullptr, invoke this instead of the actual XLA compilation. Used
// only for testing.
std::function<absl::Status()> fake_xla_compile_ = nullptr;
std::unordered_map<std::string, std::string> custom_compile_options_;

xla::PjRtDevice* StringToPjRtDevice(const std::string& device);

19 changes: 19 additions & 0 deletions torch_xla/torch_xla.py
@@ -116,6 +116,7 @@ def compile(
full_graph: Optional[bool] = False,
name: Optional[str] = None,
max_different_graphs: Optional[int] = None,
custom_compile_options: Optional[dict[str, Any]] = None,
):
"""
Optimizes the given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,11 @@
max_different_graphs (Optional[int]): number of different traced graphs of the given
model/function that we are allowed to have. An error will be raised in case this limit
is exceeded.
custom_compile_options (Optional[dict[str, Any]]): XLA compiler flag overrides.
Keys are XLA compiler flag names (forwarded to xla::CompileOptions.env_option_overrides),
and values may be bool, int, float, or str (stringified internally).
- {} (empty dict): clears previously set options.
- None (default): leaves previously set options unchanged (no-op).
Example::
@@ -215,6 +221,8 @@ def _compile():
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)

if custom_compile_options is not None:
torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
return _compile() if f is None else _compile()(f)


@@ -264,3 +272,14 @@ def launch(
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
else:
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)


def set_custom_compile_options(options: dict[str, Any]) -> None:
  """Set XLA compiler flag overrides (env option overrides) for compilation.

  Args:
    options: Dict mapping XLA flag names to values. Values may be
      bool/float/int/str; they will be stringified before being passed to XLA.
      Pass an empty dict `{}` to clear previously set options.
  """
  torch_xla._XLAC._set_custom_compile_options(options)
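Putting it together, both the compile() keyword and this helper feed the same backend state. A minimal end-to-end sketch (the flag name is illustrative):

import torch_xla

def step(x):
  return x * 2

# Per-compile: options are forwarded to the backend when compile() is invoked.
compiled = torch_xla.compile(
    step, custom_compile_options={"xla_gpu_autotune_level": 2})

# Or set globally, affecting subsequent compilations until overwritten:
torch_xla.set_custom_compile_options({"xla_gpu_autotune_level": "2"})
torch_xla.set_custom_compile_options({})  # clear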