Skip to content

Commit 0df1c92

Browse files
committed
fix one case where SymInt cannot be printed
1 parent 4634489 commit 0df1c92

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.8
55
+++++
66

7+
* :pr:`372`: fix patch on rotary embedding
78
* :pr:`371`: fix make_fake_with_dynamic_dimensions
89

910
0.8.7

onnx_diagnostic/helpers/helper.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,35 @@ def string_type(
704704
if obj.__class__.__name__ == "VirtualTensor":
705705
if verbose:
706706
print(f"[string_type] TT4:{type(obj)}")
707+
708+
def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F821
709+
if isinstance(value, str):
710+
return value
711+
if hasattr(value, "node") and isinstance(value.node, str):
712+
return f"{value.node}"
713+
714+
from torch.fx.experimental.sym_node import SymNode
715+
716+
if hasattr(value, "node") and isinstance(value.node, SymNode):
717+
# '_expr' is safer than expr
718+
return str(value.node._expr).replace(" ", "")
719+
720+
try:
721+
val_int = int(value)
722+
return val_int
723+
except (
724+
TypeError,
725+
ValueError,
726+
AttributeError,
727+
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
728+
):
729+
pass
730+
731+
raise AssertionError(f"Unable to convert {value!r} into string")
732+
707733
return (
708734
f"{obj.__class__.__name__}(name={obj.name!r}, "
709-
f"dtype={obj.dtype}, shape={obj.shape})"
735+
f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})"
710736
)
711737

712738
if obj.__class__.__name__ == "KeyValuesWrapper":

0 commit comments

Comments
 (0)