Skip to content

Commit 9a517a9

Browse files
authored
enable_python_dispatcher in some XLA custom passes. (#9312)
1 parent 728043e commit 9a517a9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torch_xla/experimental/unbounded_dynamism_export.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.fx import Graph, GraphModule, subgraph_rewriter
99
from torch.utils import _pytree as pytree
1010
from torch.utils._pytree import tree_map
11+
from torch._dispatch.python import enable_python_dispatcher
1112

1213
aten = torch.ops.aten
1314

@@ -30,7 +31,8 @@ def call_function(
3031
) -> torch.fx.Node:
3132
node = graph.call_function(target, args, kwargs)
3233
_, args, kwargs = get_fake_args_kwargs(node)
33-
node.meta["val"] = target(*args, **kwargs)
34+
with enable_python_dispatcher():
35+
node.meta["val"] = target(*args, **kwargs)
3436
return node
3537

3638

0 commit comments

Comments
 (0)