@@ -50,17 +50,19 @@ def forward(self, x):
5050
5151
5252def _tosa_FP_pipeline (module : torch .nn .Module , test_data : input_t1 , dump_file = None ):
53-
54- pipeline = TosaPipelineFP [input_t1 ](module , test_data , [], [])
53+ aten_ops : list [str ] = []
54+ exir_ops : list [str ] = []
55+ pipeline = TosaPipelineFP [input_t1 ](module , test_data , aten_ops , exir_ops )
5556 pipeline .dump_artifact ("to_edge_transform_and_lower" )
5657 pipeline .dump_artifact ("to_edge_transform_and_lower" , suffix = dump_file )
5758 pipeline .pop_stage ("run_method_and_compare_outputs" )
5859 pipeline .run ()
5960
6061
6162def _tosa_INT_pipeline (module : torch .nn .Module , test_data : input_t1 , dump_file = None ):
62-
63- pipeline = TosaPipelineINT [input_t1 ](module , test_data , [], [])
63+ aten_ops : list [str ] = []
64+ exir_ops : list [str ] = []
65+ pipeline = TosaPipelineINT [input_t1 ](module , test_data , aten_ops , exir_ops )
6466 pipeline .dump_artifact ("to_edge_transform_and_lower" )
6567 pipeline .dump_artifact ("to_edge_transform_and_lower" , suffix = dump_file )
6668 pipeline .pop_stage ("run_method_and_compare_outputs" )
@@ -105,11 +107,13 @@ def test_INT_artifact(test_data: input_t1):
105107
106108@common .parametrize ("test_data" , Linear .inputs )
107109def test_numerical_diff_print (test_data : input_t1 ):
110+ aten_ops : list [str ] = []
111+ exir_ops : list [str ] = []
108112 pipeline = TosaPipelineINT [input_t1 ](
109113 Linear (),
110114 test_data ,
111- [] ,
112- [] ,
115+ aten_ops ,
116+ exir_ops ,
113117 custom_path = "diff_print_test" ,
114118 )
115119 pipeline .pop_stage ("run_method_and_compare_outputs" )
@@ -131,7 +135,9 @@ def test_numerical_diff_print(test_data: input_t1):
131135
132136@common .parametrize ("test_data" , Linear .inputs )
133137def test_dump_ops_and_dtypes (test_data : input_t1 ):
134- pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
138+ aten_ops : list [str ] = []
139+ exir_ops : list [str ] = []
140+ pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
135141 pipeline .pop_stage ("run_method_and_compare_outputs" )
136142 pipeline .add_stage_after ("quantize" , pipeline .tester .dump_dtype_distribution )
137143 pipeline .add_stage_after ("quantize" , pipeline .tester .dump_operator_distribution )
@@ -149,7 +155,9 @@ def test_dump_ops_and_dtypes(test_data: input_t1):
149155
150156@common .parametrize ("test_data" , Linear .inputs )
151157def test_dump_ops_and_dtypes_parseable (test_data : input_t1 ):
152- pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
158+ aten_ops : list [str ] = []
159+ exir_ops : list [str ] = []
160+ pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
153161 pipeline .pop_stage ("run_method_and_compare_outputs" )
154162 pipeline .add_stage_after ("quantize" , pipeline .tester .dump_dtype_distribution , False )
155163 pipeline .add_stage_after (
@@ -177,7 +185,9 @@ def test_collate_tosa_INT_tests(test_data: input_t1):
177185 # Set the environment variable to trigger the collation of TOSA tests
178186 os .environ ["TOSA_TESTCASES_BASE_PATH" ] = "test_collate_tosa_tests"
179187 # Clear out the directory
180- pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
188+ aten_ops : list [str ] = []
189+ exir_ops : list [str ] = []
190+ pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
181191 pipeline .pop_stage ("run_method_and_compare_outputs" )
182192 pipeline .run ()
183193
@@ -197,11 +207,13 @@ def test_collate_tosa_INT_tests(test_data: input_t1):
197207@common .parametrize ("test_data" , Linear .inputs )
198208def test_dump_tosa_debug_json (test_data : input_t1 ):
199209 with tempfile .TemporaryDirectory () as tmpdir :
210+ aten_ops : list [str ] = []
211+ exir_ops : list [str ] = []
200212 pipeline = TosaPipelineINT [input_t1 ](
201213 module = Linear (),
202214 test_data = test_data ,
203- aten_op = [] ,
204- exir_op = [] ,
215+ aten_op = aten_ops ,
216+ exir_op = exir_ops ,
205217 custom_path = tmpdir ,
206218 tosa_debug_mode = ArmCompileSpec .DebugMode .JSON ,
207219 )
@@ -228,11 +240,13 @@ def test_dump_tosa_debug_json(test_data: input_t1):
228240@common .parametrize ("test_data" , Linear .inputs )
229241def test_dump_tosa_debug_tosa (test_data : input_t1 ):
230242 with tempfile .TemporaryDirectory () as tmpdir :
243+ aten_ops : list [str ] = []
244+ exir_ops : list [str ] = []
231245 pipeline = TosaPipelineINT [input_t1 ](
232246 module = Linear (),
233247 test_data = test_data ,
234- aten_op = [] ,
235- exir_op = [] ,
248+ aten_op = aten_ops ,
249+ exir_op = exir_ops ,
236250 custom_path = tmpdir ,
237251 tosa_debug_mode = ArmCompileSpec .DebugMode .TOSA ,
238252 )
@@ -248,7 +262,9 @@ def test_dump_tosa_debug_tosa(test_data: input_t1):
248262
249263@common .parametrize ("test_data" , Linear .inputs )
250264def test_dump_tosa_ops (caplog , test_data : input_t1 ):
251- pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , [], [])
265+ aten_ops : list [str ] = []
266+ exir_ops : list [str ] = []
267+ pipeline = TosaPipelineINT [input_t1 ](Linear (), test_data , aten_ops , exir_ops )
252268 pipeline .pop_stage ("run_method_and_compare_outputs" )
253269 pipeline .dump_operator_distribution ("to_edge_transform_and_lower" )
254270 pipeline .run ()
@@ -267,8 +283,10 @@ def forward(self, x):
267283@common .parametrize ("test_data" , Add .inputs )
268284@common .XfailIfNoCorstone300
269285def test_fail_dump_tosa_ops (caplog , test_data : input_t1 ):
286+ aten_ops : list [str ] = []
287+ exir_ops : list [str ] = []
270288 pipeline = EthosU55PipelineINT [input_t1 ](
271- Add (), test_data , [], [] , use_to_edge_transform_and_lower = True
289+ Add (), test_data , aten_ops , exir_ops , use_to_edge_transform_and_lower = True
272290 )
273291 pipeline .dump_operator_distribution ("to_edge_transform_and_lower" )
274292 pipeline .run ()
0 commit comments