diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a52ecc8124e..66657e4a254 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -3306,6 +3306,19 @@ void InitXlaModuleBindings(py::module m) {
             XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
                            "without a data handle or an IR.";
           })
+      .def("_set_custom_compile_options",
+           [](const py::dict& compile_options) {
+             std::unordered_map<std::string, std::string> options;
+             for (const auto& item : compile_options) {
+               // Keys must be strings; values are stringified.
+               const std::string key = py::str(item.first);
+               options[key] = py::str(item.second);
+             }
+             XLA_ASSIGN_OR_THROW(
+                 runtime::ComputationClient* absl_nonnull client,
+                 runtime::GetComputationClient());
+             client->SetCustomCompileOptions(options);
+           })
       .def(
           // from an XLA tensor to a PyCapsule.
           // When consuming the PyCapsule, we should synchronize
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 79ff199eb2f..364641a8869 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -446,6 +446,14 @@ class ComputationClient {
   // after the last ':' character of the device string.
   static int64_t GetDeviceOrdinal(const std::string& device);
 
+  // Sets XLA compile option overrides used by the backend compiler.
+  // - The map keys are XLA compiler flag names (env option override keys).
+  // - The values are stringified flag values.
+  // - Calling this method **overwrites** any previously set options.
+  //   (Pass an empty map to clear.)
+  virtual void SetCustomCompileOptions(
+      const std::unordered_map<std::string, std::string>& options) = 0;
+
  protected:
   static constexpr auto spmd_device_str = "SPMD:0";
 
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 8b45922c397..6e787545710 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -172,6 +172,11 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
+  void SetCustomCompileOptions(
+      const std::unordered_map<std::string, std::string>& options) override {
+    XLA_ERROR() << __FUNCTION__ << " not implemented";
+  }
+
   // Creates a new instance of IfrtComputationClient and initializes it.
   static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>> Create();
 
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
index 280b50964d8..908af49cdb9 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -554,6 +554,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 
   for (auto& instance : instances) {
     xla::CompileOptions compile_options;
+    for (const auto& [name, value] : custom_compile_options_) {
+      compile_options.env_option_overrides.push_back({name, value});
+    }
     if (enable_cm_in_mp) {
       compile_options.executable_build_options.set_use_spmd_partitioning(true);
       compile_options.env_option_overrides.push_back(
@@ -561,6 +564,7 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       compile_options.env_option_overrides.push_back(
           {"xla_tpu_decompose_einsum_reduce_scatter", true});
     }
+
     if (instance.is_sharded) {
       // TODO(yeounoh) multi-host, multi-slice configurations
       compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1052,5 +1056,13 @@ void PjRtComputationClient::OnReadyCallback(
       [callback](absl::Status unused) { callback(); });
 }
 
+void PjRtComputationClient::SetCustomCompileOptions(
+    const std::unordered_map<std::string, std::string>& options) {
+  custom_compile_options_.clear();
+  for (const auto& [key, value] : options) {
+    custom_compile_options_[key] = value;
+  }
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index d550f1cce0c..fc516a7042c 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -165,6 +165,10 @@ class PjRtComputationClient : public ComputationClient {
   void OnReadyCallback(DataPtr data,
                        const std::function<void()>& callback) override;
 
+  // See base class for semantics. This call overwrites previously set options.
+  void SetCustomCompileOptions(
+      const std::unordered_map<std::string, std::string>& options) override;
+
   // Creates a new instance of PjRtComputationClient and initializes it.
   static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>> Create();
 
@@ -197,6 +201,7 @@ class PjRtComputationClient : public ComputationClient {
   // If not nullptr, invoke this instead of the actual XLA compilation. Used
   // only for testing.
   std::function<absl::Status()> fake_xla_compile_ = nullptr;
+  std::unordered_map<std::string, std::string> custom_compile_options_;
 
   xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
 
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index 9062d6a9ef2..76f209f008b 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -116,6 +116,7 @@ def compile(
     full_graph: Optional[bool] = False,
     name: Optional[str] = None,
     max_different_graphs: Optional[int] = None,
+    custom_compile_options: Optional[dict[str, Any]] = None,
 ):
   """
   Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,11 @@ def compile(
     max_different_graphs (Optional[int]): number of different traced graphs of the
       given model/function that we are allowed to have. An error will be
       raised in case this limit is exceeded.
+    custom_compile_options (Optional[dict[str, Any]]): XLA compiler flag overrides.
+      Keys are XLA compiler flag names (forwarded to xla::CompileOptions.env_option_overrides),
+      and values may be bool, int, float, or str (internally stringified).
+      - {} (empty dict): clear previously set options.
+      - None (default): do not change previously set options (no-op).
 
   Example::
 
@@ -215,6 +221,8 @@ def _compile():
       torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
       torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
 
+  if custom_compile_options is not None:
+    torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
   return _compile() if f is None else _compile()(f)
 
@@ -264,3 +272,14 @@ def launch(
     fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
   else:
     xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
+
+
+def set_custom_compile_options(options: dict[str, Any]) -> None:
+  """Set XLA **compiler flag overrides** (env option overrides) for compilation.
+
+  Args:
+    options: Dict mapping XLA flag names to values. Values may be bool/float/int/str;
+      they will be stringified before being passed to XLA.
+      Pass an empty dict `{}` to clear previously set options.
+  """
+  torch_xla._XLAC._set_custom_compile_options(options)