Skip to content

Commit 69e41ce

Browse files
ydwu4 authored and pytorchmergebot committed
move find_hop_schema into _higher_order_ops/schema.py (pytorch#151147)
Pull Request resolved: pytorch#151147 Approved by: https://github.com/zou3519
1 parent 5acc3e2 commit 69e41ce

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

test/dynamo/test_base_hop.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Owner(s): ["module: dynamo"]
22
import unittest
3-
from typing import Any
43

54
import torch
65
import torch._dynamo.test_case
@@ -11,6 +10,7 @@
1110
EagerAndRecordGraphs,
1211
normalize_gm,
1312
)
13+
from torch._higher_order_ops.schema import find_hop_schema
1414
from torch.testing._internal.inductor_utils import HAS_CUDA
1515

1616

@@ -74,29 +74,6 @@ def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
7474
""", # NOQA: B950
7575
)
7676

77-
def _find_hop_schema(
78-
self, gm: torch.fx.GraphModule, target: Any
79-
) -> list[torch._C.FunctionSchema]:
80-
import torch.utils._pytree as pytree
81-
82-
schemas = []
83-
for node in gm.graph.find_nodes(op="call_function", target=target):
84-
85-
def _get_example_value(node: torch.fx.Node) -> Any:
86-
if node.op == "get_attr":
87-
return getattr(gm, node.target)
88-
else:
89-
return node.meta["example_value"]
90-
91-
fake_args, fake_kwargs = pytree.tree_map_only(
92-
torch.fx.Node,
93-
_get_example_value,
94-
(node.args, node.kwargs),
95-
)
96-
schema = node.target.gen_schema(*fake_args, **fake_kwargs)
97-
schemas.append(schema)
98-
return schemas
99-
10077
def test_schema_gen_single_return(self):
10178
def inner(x, y):
10279
return (x @ y).sin().cos()
@@ -112,7 +89,7 @@ def f(x, y):
11289

11390
out = f(x.clone(), y)
11491
self.assertEqual(out, inner(x.clone(), y))
115-
schemas = self._find_hop_schema(backend.graphs[0], invoke_quant_test)
92+
schemas = find_hop_schema(backend.graphs[0], invoke_quant_test)
11693
self.assertEqual(len(schemas), 1)
11794
self.assertExpectedInline(
11895
str(schemas[0]),
@@ -140,7 +117,7 @@ def f(x, y):
140117

141118
out = f(x.clone(), y)
142119
self.assertEqual(out, inner([x.clone(), y]))
143-
schemas = self._find_hop_schema(backend.graphs[0], invoke_quant_test)
120+
schemas = find_hop_schema(backend.graphs[0], invoke_quant_test)
144121
self.assertEqual(len(schemas), 1)
145122
self.assertExpectedInline(
146123
str(schemas[0]),

torch/_higher_order_ops/schema.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Optional
33

44
import torch
5+
from torch.fx.node import Target
56

67

78
# Below is an implementation of generating FunctionSchema from example values.
@@ -152,3 +153,28 @@ def from_hop_argument_info(
152153
False,
153154
False,
154155
)
156+
157+
158+
def find_hop_schema(
159+
gm: torch.fx.GraphModule, target: Target
160+
) -> list[torch._C.FunctionSchema]:
161+
import torch.utils._pytree as pytree
162+
163+
schemas = []
164+
for node in gm.graph.find_nodes(op="call_function", target=target):
165+
166+
def _get_example_value(node: torch.fx.Node) -> Any:
167+
if node.op == "get_attr":
168+
assert isinstance(node.target, str)
169+
return getattr(gm, node.target)
170+
else:
171+
return node.meta["example_value"]
172+
173+
fake_args, fake_kwargs = pytree.tree_map_only(
174+
torch.fx.Node,
175+
_get_example_value,
176+
(node.args, node.kwargs),
177+
)
178+
schema = node.target.gen_schema(*fake_args, **fake_kwargs)
179+
schemas.append(schema)
180+
return schemas

0 commit comments

Comments
 (0)