Commit ccd5d81

YUNQIUGUO authored and pytorchmergebot committed
[aoti] follow up to use new api in test_provenance_tracing.py (pytorch#149387)
Summary: As title. Follow up of D71181284. and some minor refactoring Context : D69609685 (update test runner to use new api) / pytorch#147105 Test Plan: ``` buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_to_post_grad_tracing_cpu ``` Differential Revision: D71375725 Pull Request resolved: pytorch#149387 Approved by: https://github.com/yushangdi
1 parent 5327894 · commit ccd5d81

File tree

1 file changed: +98 −120 lines

test/inductor/test_provenance_tracing.py

Lines changed: 98 additions & 120 deletions
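The change consolidates the CUDA and CPU provenance tests into a single device-parameterized helper and hoists the expected mappings out of `_check_provenance_tracing_artifact` to the call sites. For context, the artifact under test, `inductor_triton_kernel_to_post_grad_nodes.json`, maps each generated kernel name to the post-grad FX nodes it implements. Here is a sketch of the CUDA artifact's contents, reconstructed from the test's `expected_data` in the diff below (illustrative, not captured output):

```python
# Approximate contents of inductor_triton_kernel_to_post_grad_nodes.json for
# the CUDA run; reconstructed from the test's expected_data, not captured output.
cuda_artifact = {
    "triton_poi_fused_mul_0": ["mul"],
    "triton_poi_fused_addmm_gelu_1": [
        "mul_3",
        "mul_1",
        "add_tensor",
        "add",
        "erf",
        "mul_2",
    ],
}
```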
```diff
@@ -40,125 +40,29 @@ class TestProvenanceTracingArtifact(TestCase):
     corresponding "inductor triton kernel node" is expected.
     """
 
-    @requires_cuda
-    def _check_provenance_tracing_artifact(self, filepath):
+    def _check_provenance_tracing_artifact(self, filepath, expected_data):
         self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
         with open(filename) as f:
             actual_data = json.load(f)
         # check that the generated provenance tracing artifact is expected
-        expected_data = {
-            "triton_poi_fused_mul_0": ["mul"],
-            "triton_poi_fused_addmm_gelu_1": [
-                "mul_3",
-                "mul_1",
-                "add_tensor",
-                "add",
-                "erf",
-                "mul_2",
-            ],
-        }
         self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))
 
+    def _check_provenance_tracking_node_mappings(self, filepath, expected_mapping):
+        self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
         with open(filename) as f:
             actual_data = json.load(f)
-        # check that the generated provenance tracing artifact is expected
-        expected_data = [
-            (
-                "cppCodeToPost",
-                {
-                    "triton_poi_fused_mul_0": ["mul"],
-                    "triton_poi_fused_addmm_gelu_1": [
-                        "mul_3",
-                        "mul_1",
-                        "add_tensor",
-                        "add",
-                        "erf",
-                        "mul_2",
-                    ],
-                },
-            ),
-            (
-                "postToCppCode",
-                {
-                    "mul": ["triton_poi_fused_mul_0"],
-                    "mul_3": ["triton_poi_fused_addmm_gelu_1"],
-                    "mul_1": ["triton_poi_fused_addmm_gelu_1"],
-                    "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
-                    "add": ["triton_poi_fused_addmm_gelu_1"],
-                    "erf": ["triton_poi_fused_addmm_gelu_1"],
-                    "mul_2": ["triton_poi_fused_addmm_gelu_1"],
-                },
-            ),
-            (
-                "postToPre",
-                {
-                    "mul": ["mul"],
-                    "mm_default": ["addmm"],
-                    "add_tensor": ["addmm"],
-                    "mul_1": ["gelu"],
-                    "mul_2": ["gelu"],
-                    "erf": ["gelu"],
-                    "add": ["gelu"],
-                    "mul_3": ["gelu"],
-                },
-            ),
-            (
-                "preToPost",
-                {
-                    "mul": ["mul"],
-                    "addmm": ["mm_default", "add_tensor"],
-                    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
-                },
-            ),
-        ]
-        self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
+        # check that the generated provenance tracing node mapping is expected
+        self.assertEqual(sorted(actual_data.items()), sorted(expected_mapping))
 
-    @requires_cuda
-    def test_triton_kernel_to_post_grad_tracing(self):
-        a = torch.randn(10, 20, device="cuda")
-        b = torch.randn(20, 30, device="cuda")
-        c = torch.randn(10, 30, device="cuda")
+    def _test_triton_kernel_to_post_grad_tracing(self, device):
+        a = torch.randn(10, 20, device=device)
+        b = torch.randn(20, 30, device=device)
+        c = torch.randn(10, 30, device=device)
         example_inputs = (a, b, c)
 
         model = Model()
-        for backend in ["aot_inductor", "inductor"]:
-            try:
-                with config.patch(
-                    {
-                        "trace.debug_dir": tempfile.mkdtemp(),
-                        "force_disable_caches": True,
-                    }
-                ):
-                    with self.assertLogs(
-                        logging.getLogger("torch._inductor.debug"),
-                        level=logging.WARNING,
-                    ) as cm:
-                        if backend == "aot_inductor":
-                            AOTIRunnerUtil.run(model, example_inputs)
-                        else:
-                            ep = torch.export._trace._export(model, example_inputs)
-                            compiled = torch.compile(ep.module(), backend=backend)
-                            compiled(*example_inputs)
-                self.assertEqual(len(cm.output), 1)
-                m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
-                self.assertTrue(m)
-                filepath = Path(m.group(1))
-                self._check_provenance_tracing_artifact(filepath)
-            finally:
-                shutil.rmtree(filepath)
-
-    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
-    def test_triton_kernel_to_post_grad_tracing_cpu(self):
-        a = torch.randn(10, 20, device="cpu")
-        b = torch.randn(20, 30, device="cpu")
-        c = torch.randn(10, 30, device="cpu")
-        example_inputs = (a, b, c)
-
-        model = Model()
-        ep = torch.export._trace._export(model, example_inputs)
-        gm = ep.module()
         filepath = None
 
         for backend in ["aot_inductor", "inductor"]:
@@ -176,31 +80,105 @@ def test_triton_kernel_to_post_grad_tracing_cpu(self):
                         if backend == "aot_inductor":
                             AOTIRunnerUtil.run(model, example_inputs)
                         else:
-                            compiled = torch.compile(gm, backend=backend)
+                            ep = torch.export._trace._export(model, example_inputs)
+                            compiled = torch.compile(ep.module(), backend=backend)
                             compiled(*example_inputs)
                 self.assertEqual(len(cm.output), 1)
                 m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
                 self.assertTrue(m)
                 filepath = Path(m.group(1))
-                filename = (
-                    Path(filepath)
-                    / "inductor_triton_kernel_to_post_grad_nodes.json"
-                )
-                with open(filename) as f:
-                    actual_data = json.load(f)
-                # check the inductor kernel to post grad nodes mapping is expected for cpu
-                expected_data = {
-                    "cpp_fused_mul_0": ["mul"],
-                    "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
-                }
-                self.assertEqual(
-                    sorted(actual_data.items()), sorted(expected_data.items())
-                )
+                if device == "cuda":
+                    expected_data = {
+                        "triton_poi_fused_mul_0": ["mul"],
+                        "triton_poi_fused_addmm_gelu_1": [
+                            "mul_3",
+                            "mul_1",
+                            "add_tensor",
+                            "add",
+                            "erf",
+                            "mul_2",
+                        ],
+                    }
+                    self._check_provenance_tracing_artifact(filepath, expected_data)
+                    expected_mapping = [
+                        (
+                            "cppCodeToPost",
+                            {
+                                "triton_poi_fused_mul_0": ["mul"],
+                                "triton_poi_fused_addmm_gelu_1": [
+                                    "mul_3",
+                                    "mul_1",
+                                    "add_tensor",
+                                    "add",
+                                    "erf",
+                                    "mul_2",
+                                ],
+                            },
+                        ),
+                        (
+                            "postToCppCode",
+                            {
+                                "mul": ["triton_poi_fused_mul_0"],
+                                "mul_3": ["triton_poi_fused_addmm_gelu_1"],
+                                "mul_1": ["triton_poi_fused_addmm_gelu_1"],
+                                "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
+                                "add": ["triton_poi_fused_addmm_gelu_1"],
+                                "erf": ["triton_poi_fused_addmm_gelu_1"],
+                                "mul_2": ["triton_poi_fused_addmm_gelu_1"],
+                            },
+                        ),
+                        (
+                            "postToPre",
+                            {
+                                "mul": ["mul"],
+                                "mm_default": ["addmm"],
+                                "add_tensor": ["addmm"],
+                                "mul_1": ["gelu"],
+                                "mul_2": ["gelu"],
+                                "erf": ["gelu"],
+                                "add": ["gelu"],
+                                "mul_3": ["gelu"],
+                            },
+                        ),
+                        (
+                            "preToPost",
+                            {
+                                "mul": ["mul"],
+                                "addmm": ["mm_default", "add_tensor"],
+                                "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
+                            },
+                        ),
+                    ]
+                    self._check_provenance_tracking_node_mappings(
+                        filepath, expected_mapping
+                    )
+                else:
+                    assert device == "cpu"
+                    # check the inductor kernel to post grad nodes mapping is expected for cpu
+                    expected_data = {
+                        "cpp_fused_mul_0": ["mul"],
+                        "cpp_fused_gelu_1": [
+                            "mul_3",
+                            "mul_1",
+                            "add",
+                            "erf",
+                            "mul_2",
+                        ],
+                    }
+                    self._check_provenance_tracing_artifact(filepath, expected_data)
 
             finally:
                 if filepath:
                     shutil.rmtree(filepath)
 
+    @requires_cuda
+    def test_triton_kernel_to_post_grad_tracing_cuda(self):
+        self._test_triton_kernel_to_post_grad_tracing(device="cuda")
+
+    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
+    def test_triton_kernel_to_post_grad_tracing_cpu(self):
+        self._test_triton_kernel_to_post_grad_tracing(device="cpu")
+
 
 class TestProvenanceTracingNodeMapping(TestCase):
     def test_create_node_mapping(self):
```
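With the expectations hoisted to the call sites, each checker helper reduces to load-and-compare. A minimal standalone sketch of the pattern follows (condensed from the diff above; the `check_artifact` name and the omitted `TestCase` plumbing are illustrative):

```python
import json
from pathlib import Path


def check_artifact(filepath, expected_data):
    # Load the kernel -> post-grad-node mapping written by inductor's debug
    # tracing, then compare it against the caller-supplied expectations.
    filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
    with open(filename) as f:
        actual_data = json.load(f)
    assert sorted(actual_data.items()) == sorted(expected_data.items())


# CPU call site, with expectations taken from the diff above; debug_trace_dir
# would come from the "debug trace" path logged by torch._inductor.debug.
cpu_expected = {
    "cpp_fused_mul_0": ["mul"],
    "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
}
# check_artifact(debug_trace_dir, cpu_expected)
```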
