Skip to content

Commit fb58b61

Browse files
committed
improvment
1 parent fc60c36 commit fb58b61

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

onnx_diagnostic/helpers/helper.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def string_type(
252252
if all(isinstance(k, int) for k in obj):
253253
return f"{{{s}}}"
254254
return f"dict({s})"
255-
# arrat
255+
# array
256256
if isinstance(obj, np.ndarray):
257257
from .onnx_helper import np_dtype_to_tensor_dtype
258258

@@ -286,6 +286,16 @@ def string_type(
286286
return "SymInt"
287287
if isinstance(obj, torch.SymFloat):
288288
return "SymFloat"
289+
if isinstance(obj, torch.export.dynamic_shapes._DimHintType):
290+
if obj == torch.export.dynamic_shapes._DimHintType.DYNAMIC:
291+
return "DYNAMIC"
292+
if obj == torch.export.dynamic_shapes._DimHintType.AUTO:
293+
return "AUTO"
294+
return str(obj)
295+
if obj in (torch.export.Dim.DYNAMIC, torch.export.dynamic_shapes._DimHintType.DYNAMIC):
296+
return "DYNAMIC"
297+
if obj == (torch.export.Dim.AUTO, torch.export.dynamic_shapes._DimHintType.AUTO):
298+
return "AUTO"
289299
# Tensors
290300
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
291301
from .onnx_helper import torch_dtype_to_onnx_dtype
@@ -410,6 +420,10 @@ def string_type(
410420
)
411421

412422
if obj.__class__.__name__ in ("_DimHint", "_DimHintType"):
423+
if obj in (torch.export.Dim.DYNAMIC, torch.export.dynamic_shapes._DimHintType.DYNAMIC):
424+
return "DYNAMIC"
425+
if obj == (torch.export.Dim.AUTO, torch.export.dynamic_shapes._DimHintType.AUTO):
426+
return "AUTO"
413427
return str(obj)
414428

415429
if isinstance(obj, torch.nn.Module):

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,18 @@ def call_torch_export_export(
646646
strict = "nostrict" not in exporter
647647
args, kwargs = split_args_kwargs(data["inputs_export"])
648648
ds = data.get("dynamic_shapes", None)
649+
650+
summary["export_exporter"] = exporter
651+
summary["export_optimization"] = optimization or ""
652+
summary["export_strict"] = strict
653+
summary["export_args"] = string_type(args, with_shape=True)
654+
summary["export_kwargs"] = string_type(kwargs, with_shape=True)
655+
summary["export_dynamic_shapes"] = string_type(ds)
656+
657+
# There is an issue with DynamicShape [[],[]] becomes []
658+
dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by()
659+
summary["export_dynamic_shapes_export_export"] = string_type(dse)
660+
649661
if verbose:
650662
print(
651663
f"[call_torch_export_export] exporter={exporter!r}, "
@@ -654,18 +666,14 @@ def call_torch_export_export(
654666
print(f"[call_torch_export_export] args={string_type(args, with_shape=True)}")
655667
print(f"[call_torch_export_export] kwargs={string_type(kwargs, with_shape=True)}")
656668
print(f"[call_torch_export_export] dynamic_shapes={_ds_clean(ds)}")
669+
print(f"[call_torch_export_export] dynamic_shapes_export_export={string_type(dse)}")
657670
print("[call_torch_export_export] export...")
658-
summary["export_exporter"] = exporter
659-
summary["export_optimization"] = optimization or ""
660-
summary["export_strict"] = strict
661-
summary["export_args"] = string_type(args, with_shape=True)
662-
summary["export_kwargs"] = string_type(kwargs, with_shape=True)
663671

664672
begin = time.perf_counter()
665673
if quiet:
666674
try:
667675
ep = torch.export.export(
668-
data["model"], args, kwargs=kwargs, dynamic_shapes=ds, strict=strict
676+
data["model"], args, kwargs=kwargs, dynamic_shapes=dse, strict=strict
669677
)
670678
except Exception as e:
671679
summary["ERR_export_export"] = str(e)
@@ -674,7 +682,7 @@ def call_torch_export_export(
674682
return summary, data
675683
else:
676684
ep = torch.export.export(
677-
data["model"], args, kwargs=kwargs, dynamic_shapes=ds, strict=strict
685+
data["model"], args, kwargs=kwargs, dynamic_shapes=dse, strict=strict
678686
)
679687

680688
summary["time_export_export"] = time.perf_counter() - begin

0 commit comments

Comments
 (0)