@@ -40,125 +40,29 @@ class TestProvenanceTracingArtifact(TestCase):
     corresponding "inductor triton kernel node" is expected.
     """

-    @requires_cuda
-    def _check_provenance_tracing_artifact(self, filepath):
+    def _check_provenance_tracing_artifact(self, filepath, expected_data):
         self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
         with open(filename) as f:
             actual_data = json.load(f)
         # check that the generated provenance tracing artifact is expected
-        expected_data = {
-            "triton_poi_fused_mul_0": ["mul"],
-            "triton_poi_fused_addmm_gelu_1": [
-                "mul_3",
-                "mul_1",
-                "add_tensor",
-                "add",
-                "erf",
-                "mul_2",
-            ],
-        }
         self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))

+    def _check_provenance_tracking_node_mappings(self, filepath, expected_mapping):
+        self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
         with open(filename) as f:
             actual_data = json.load(f)
-        # check that the generated provenance tracing artifact is expected
-        expected_data = [
-            (
-                "cppCodeToPost",
-                {
-                    "triton_poi_fused_mul_0": ["mul"],
-                    "triton_poi_fused_addmm_gelu_1": [
-                        "mul_3",
-                        "mul_1",
-                        "add_tensor",
-                        "add",
-                        "erf",
-                        "mul_2",
-                    ],
-                },
-            ),
-            (
-                "postToCppCode",
-                {
-                    "mul": ["triton_poi_fused_mul_0"],
-                    "mul_3": ["triton_poi_fused_addmm_gelu_1"],
-                    "mul_1": ["triton_poi_fused_addmm_gelu_1"],
-                    "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
-                    "add": ["triton_poi_fused_addmm_gelu_1"],
-                    "erf": ["triton_poi_fused_addmm_gelu_1"],
-                    "mul_2": ["triton_poi_fused_addmm_gelu_1"],
-                },
-            ),
-            (
-                "postToPre",
-                {
-                    "mul": ["mul"],
-                    "mm_default": ["addmm"],
-                    "add_tensor": ["addmm"],
-                    "mul_1": ["gelu"],
-                    "mul_2": ["gelu"],
-                    "erf": ["gelu"],
-                    "add": ["gelu"],
-                    "mul_3": ["gelu"],
-                },
-            ),
-            (
-                "preToPost",
-                {
-                    "mul": ["mul"],
-                    "addmm": ["mm_default", "add_tensor"],
-                    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
-                },
-            ),
-        ]
-        self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
+        # check that the generated provenance tracing node mapping is expected
+        self.assertEqual(sorted(actual_data.items()), sorted(expected_mapping))

-    @requires_cuda
-    def test_triton_kernel_to_post_grad_tracing(self):
-        a = torch.randn(10, 20, device="cuda")
-        b = torch.randn(20, 30, device="cuda")
-        c = torch.randn(10, 30, device="cuda")
+    def _test_triton_kernel_to_post_grad_tracing(self, device):
+        a = torch.randn(10, 20, device=device)
+        b = torch.randn(20, 30, device=device)
+        c = torch.randn(10, 30, device=device)
         example_inputs = (a, b, c)

         model = Model()
-        for backend in ["aot_inductor", "inductor"]:
-            try:
-                with config.patch(
-                    {
-                        "trace.debug_dir": tempfile.mkdtemp(),
-                        "force_disable_caches": True,
-                    }
-                ):
-                    with self.assertLogs(
-                        logging.getLogger("torch._inductor.debug"),
-                        level=logging.WARNING,
-                    ) as cm:
-                        if backend == "aot_inductor":
-                            AOTIRunnerUtil.run(model, example_inputs)
-                        else:
-                            ep = torch.export._trace._export(model, example_inputs)
-                            compiled = torch.compile(ep.module(), backend=backend)
-                            compiled(*example_inputs)
-                    self.assertEqual(len(cm.output), 1)
-                    m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
-                    self.assertTrue(m)
-                    filepath = Path(m.group(1))
-                    self._check_provenance_tracing_artifact(filepath)
-            finally:
-                shutil.rmtree(filepath)
-
-    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
-    def test_triton_kernel_to_post_grad_tracing_cpu(self):
-        a = torch.randn(10, 20, device="cpu")
-        b = torch.randn(20, 30, device="cpu")
-        c = torch.randn(10, 30, device="cpu")
-        example_inputs = (a, b, c)
-
-        model = Model()
-        ep = torch.export._trace._export(model, example_inputs)
-        gm = ep.module()
         filepath = None

         for backend in ["aot_inductor", "inductor"]:
@@ -176,31 +80,105 @@ def test_triton_kernel_to_post_grad_tracing_cpu(self):
                         if backend == "aot_inductor":
                             AOTIRunnerUtil.run(model, example_inputs)
                         else:
-                            compiled = torch.compile(gm, backend=backend)
+                            ep = torch.export._trace._export(model, example_inputs)
+                            compiled = torch.compile(ep.module(), backend=backend)
                             compiled(*example_inputs)
                     self.assertEqual(len(cm.output), 1)
                     m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
                     self.assertTrue(m)
                     filepath = Path(m.group(1))
-                    filename = (
-                        Path(filepath)
-                        / "inductor_triton_kernel_to_post_grad_nodes.json"
-                    )
-                    with open(filename) as f:
-                        actual_data = json.load(f)
-                    # check the inductor kernel to post grad nodes mapping is expected for cpu
-                    expected_data = {
-                        "cpp_fused_mul_0": ["mul"],
-                        "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
-                    }
-                    self.assertEqual(
-                        sorted(actual_data.items()), sorted(expected_data.items())
-                    )
+                    if device == "cuda":
+                        expected_data = {
+                            "triton_poi_fused_mul_0": ["mul"],
+                            "triton_poi_fused_addmm_gelu_1": [
+                                "mul_3",
+                                "mul_1",
+                                "add_tensor",
+                                "add",
+                                "erf",
+                                "mul_2",
+                            ],
+                        }
+                        self._check_provenance_tracing_artifact(filepath, expected_data)
+                        expected_mapping = [
+                            (
+                                "cppCodeToPost",
+                                {
+                                    "triton_poi_fused_mul_0": ["mul"],
+                                    "triton_poi_fused_addmm_gelu_1": [
+                                        "mul_3",
+                                        "mul_1",
+                                        "add_tensor",
+                                        "add",
+                                        "erf",
+                                        "mul_2",
+                                    ],
+                                },
+                            ),
+                            (
+                                "postToCppCode",
+                                {
+                                    "mul": ["triton_poi_fused_mul_0"],
+                                    "mul_3": ["triton_poi_fused_addmm_gelu_1"],
+                                    "mul_1": ["triton_poi_fused_addmm_gelu_1"],
+                                    "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
+                                    "add": ["triton_poi_fused_addmm_gelu_1"],
+                                    "erf": ["triton_poi_fused_addmm_gelu_1"],
+                                    "mul_2": ["triton_poi_fused_addmm_gelu_1"],
+                                },
+                            ),
+                            (
+                                "postToPre",
+                                {
+                                    "mul": ["mul"],
+                                    "mm_default": ["addmm"],
+                                    "add_tensor": ["addmm"],
+                                    "mul_1": ["gelu"],
+                                    "mul_2": ["gelu"],
+                                    "erf": ["gelu"],
+                                    "add": ["gelu"],
+                                    "mul_3": ["gelu"],
+                                },
+                            ),
+                            (
+                                "preToPost",
+                                {
+                                    "mul": ["mul"],
+                                    "addmm": ["mm_default", "add_tensor"],
+                                    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
+                                },
+                            ),
+                        ]
+                        self._check_provenance_tracking_node_mappings(
+                            filepath, expected_mapping
+                        )
+                    else:
+                        assert device == "cpu"
+                        # check the inductor kernel to post grad nodes mapping is expected for cpu
+                        expected_data = {
+                            "cpp_fused_mul_0": ["mul"],
+                            "cpp_fused_gelu_1": [
+                                "mul_3",
+                                "mul_1",
+                                "add",
+                                "erf",
+                                "mul_2",
+                            ],
+                        }
+                        self._check_provenance_tracing_artifact(filepath, expected_data)

             finally:
                 if filepath:
                     shutil.rmtree(filepath)

+    @requires_cuda
+    def test_triton_kernel_to_post_grad_tracing_cuda(self):
+        self._test_triton_kernel_to_post_grad_tracing(device="cuda")
+
+    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
+    def test_triton_kernel_to_post_grad_tracing_cpu(self):
+        self._test_triton_kernel_to_post_grad_tracing(device="cpu")
+

 class TestProvenanceTracingNodeMapping(TestCase):
     def test_create_node_mapping(self):
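For context, both refactored helpers assert against JSON artifacts that Inductor writes into its debug-trace directory. Below is a minimal standalone sketch of the check they perform, using the CPU expectations from the diff above; the directory path is a placeholder, since in the tests it is parsed from the "debug trace" warning logged by torch._inductor.debug:

import json
from pathlib import Path

# Placeholder path: the real tests obtain this directory from the warning
# emitted after compiling with config.patch({"trace.debug_dir": ...}).
filepath = Path("/tmp/inductor_debug_trace")

# Expected kernel -> post-grad-node mapping for the CPU backend, as above.
expected_data = {
    "cpp_fused_mul_0": ["mul"],
    "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
}

# Load the artifact and compare it against the expectation, mirroring
# _check_provenance_tracing_artifact.
with open(filepath / "inductor_triton_kernel_to_post_grad_nodes.json") as f:
    actual_data = json.load(f)

assert sorted(actual_data.items()) == sorted(expected_data.items())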