@@ -116,6 +116,7 @@ def compile(
116
116
full_graph : Optional [bool ] = False ,
117
117
name : Optional [str ] = None ,
118
118
max_different_graphs : Optional [int ] = None ,
119
+ custom_compile_options : Optional [dict ] = None ,
119
120
):
120
121
"""
121
122
Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,8 @@ def compile(
136
137
max_different_graphs (Optional[int]): number of different traced graphs of the given
137
138
model/function that we are allowed to have. An error will be raised in case this limit
138
139
is exceeded.
140
+ custom_compile_options (Optional[dict]): A dictionary of custom compile options to be set.
141
+ The keys are strings and the values can be of type bool, float, int, or str.
139
142
140
143
Example::
141
144
@@ -214,7 +217,8 @@ def _compile():
214
217
sync ()
215
218
torch_xla ._XLAC ._set_use_eager_mode (saved_eager_mode_status )
216
219
torch_xla ._XLAC ._set_current_graph_name (saved_current_graph_name )
217
-
220
+ if custom_compile_options is not None and len (custom_compile_options ) > 0 :
221
+ torch_xla ._XLAC ._set_custom_compile_options (custom_compile_options )
218
222
return _compile () if f is None else _compile ()(f )
219
223
220
224
@@ -264,3 +268,17 @@ def launch(
264
268
fn (xu .getenv_as (xenv .LOCAL_RANK , int ), * args )
265
269
else :
266
270
xmp .spawn (fn , args = args , nprocs = nprocs , start_method = start_method )
271
+
272
+ def set_custom_compile_options (
273
+ options : Optional [dict ] = None ,
274
+ ):
275
+ """Sets custom compile options for the XLA compilation.
276
+
277
+ Args:
278
+ options: A dictionary of custom compile options to be set.
279
+ The keys are strings and the values can be of type bool, float, int, or str.
280
+ """
281
+ if options is None :
282
+ options = {}
283
+ torch_xla ._XLAC ._set_custom_compile_options (options )
284
+
0 commit comments