Skip to content

Commit a5d3555

Browse files
committed
fix issues
1 parent 194ffc2 commit a5d3555

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,9 +1012,12 @@ def call_torch_export_onnx(
10121012
def _os_ort_optim(epo):
10131013
onnxscript.optimizer.optimize_ir(epo.model)
10141014
optimized = ort_fusions.optimize_for_ort(epo.model)
1015-
epo.model = (
1016-
optimized if isinstance(optimized, onnxscript.ir.Model) else optimized[0]
1017-
)
1015+
if isinstance(optimized, tuple):
1016+
for k, v in optimized[1].items():
1017+
summary[f"op_opt_fused_{k}"] = v
1018+
epo.model = optimized[0]
1019+
else:
1020+
epo.model = optimized
10181021

10191022
label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
10201023
_quiet_or_not_quiet(quiet, label, summary, data, f_optim)
@@ -1199,7 +1202,29 @@ def call_torch_export_custom(
11991202
print("[call_torch_export_custom] done (export)")
12001203

12011204
if os_ort:
1202-
pass
1205+
if verbose:
1206+
print("[call_torch_export_custom] conversion to IR...")
1207+
begin = time.perf_counter()
1208+
ir_model = epo.to_ir()
1209+
duration = time.perf_counter() - begin
1210+
summary["time_optim_to_ir"] = duration
1211+
if verbose:
1212+
print(f"[call_torch_export_custom] done in {duration}")
1213+
print("[call_torch_export_custom] start optimization...")
1214+
begin = time.perf_counter()
1215+
onnxscript.optimizer.optimize_ir(ir_model)
1216+
ir_optimized = ort_fusions.optimize_for_ort(ir_model)
1217+
if isinstance(ir_optimized, tuple):
1218+
report = ir_optimized[1]
1219+
for k, v in report.items():
1220+
summary[f"op_opt_fused_{k}"] = v
1221+
ir_optimized = ir_optimized[0]
1222+
epo.model = ir_optimized
1223+
duration = time.perf_counter() - begin
1224+
summary["time_optim_os_ort"] = duration
1225+
if verbose:
1226+
print(f"[call_torch_export_custom] done in {duration}")
1227+
12031228
data["onnx_program"] = epo
12041229
return summary, data
12051230

0 commit comments

Comments
 (0)