@@ -1012,9 +1012,12 @@ def call_torch_export_onnx(
10121012 def _os_ort_optim (epo ):
10131013 onnxscript .optimizer .optimize_ir (epo .model )
10141014 optimized = ort_fusions .optimize_for_ort (epo .model )
1015- epo .model = (
1016- optimized if isinstance (optimized , onnxscript .ir .Model ) else optimized [0 ]
1017- )
1015+ if isinstance (optimized , tuple ):
1016+ for k , v in optimized [1 ].items ():
1017+ summary [f"op_opt_fused_{ k } " ] = v
1018+ epo .model = optimized [0 ]
1019+ else :
1020+ epo .model = optimized
10181021
10191022 label , f_optim = "export_onnx_opt_os_ort" , (lambda epo = epo : _os_ort_optim (epo ))
10201023 _quiet_or_not_quiet (quiet , label , summary , data , f_optim )
@@ -1199,7 +1202,29 @@ def call_torch_export_custom(
11991202 print ("[call_torch_export_custom] done (export)" )
12001203
12011204 if os_ort :
1202- pass
1205+ if verbose :
1206+ print ("[call_torch_export_custom] conversion to IR..." )
1207+ begin = time .perf_counter ()
1208+ ir_model = epo .to_ir ()
1209+ duration = time .perf_counter () - begin
1210+ summary ["time_optim_to_ir" ] = duration
1211+ if verbose :
1212+ print (f"[call_torch_export_custom] done in { duration } " )
1213+ print ("[call_torch_export_custom] start optimization..." )
1214+ begin = time .perf_counter ()
1215+ onnxscript .optimizer .optimize_ir (ir_model )
1216+ ir_optimized = ort_fusions .optimize_for_ort (ir_model )
1217+ if isinstance (ir_optimized , tuple ):
1218+ report = ir_optimized [1 ]
1219+ for k , v in report .items ():
1220+ summary [f"op_opt_fused_{ k } " ] = v
1221+ ir_optimized = ir_optimized [0 ]
1222+ epo .model = ir_optimized
1223+ duration = time .perf_counter () - begin
1224+ summary ["time_optim_os_ort" ] = duration
1225+ if verbose :
1226+ print (f"[call_torch_export_custom] done in { duration } " )
1227+
12031228 data ["onnx_program" ] = epo
12041229 return summary , data
12051230
0 commit comments