@@ -142,7 +142,30 @@ def string_type(
142142 :showcode:
143143
144144 from onnx_diagnostic.helpers import string_type
145+
145146 print(string_type((1, ["r", 6.6])))
147+
148+ With pytorch:
149+
150+ .. runpython::
151+ :showcode:
152+
153+ import torch
154+ from onnx_diagnostic.helpers import string_type
155+
156+ inputs = (
157+ torch.rand((3, 4), dtype=torch.float16),
158+ [
159+ torch.rand((5, 6), dtype=torch.float16),
160+ torch.rand((5, 6, 7), dtype=torch.float16),
161+ ]
162+ )
163+
164+ # with shapes
165+ print(string_type(inputs, with_shape=True))
166+
167+ # with min max
168+ print(string_type(inputs, with_shape=True, with_min_max=True))
146169 """
147170 if obj is None :
148171 return "None"
@@ -465,7 +488,19 @@ def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
465488
466489@functools .cache
467490def onnx_dtype_name (itype : int ) -> str :
468- """Returns the ONNX name for a specific element type."""
491+ """
492+ Returns the ONNX name for a specific element type.
493+
494+ .. runpython::
495+ :showcode:
496+
497+ import onnx
498+ from onnx_diagnostic.helpers import onnx_dtype_name
499+
500+ itype = onnx.TensorProto.BFLOAT16
501+ print(onnx_dtype_name(itype))
502+ print(onnx_dtype_name(7))
503+ """
469504 for k in dir (TensorProto ):
470505 v = getattr (TensorProto , k )
471506 if v == itype :
@@ -477,19 +512,24 @@ def pretty_onnx(
477512 onx : Union [FunctionProto , GraphProto , ModelProto , ValueInfoProto , str ],
478513 with_attributes : bool = False ,
479514 highlight : Optional [Set [str ]] = None ,
515+ shape_inference : bool = False ,
480516) -> str :
481517 """
482518 Displays an onnx prot in a better way.
483519
484520 :param with_attributes: displays attributes as well, if only a node is printed
485521 :param highlight: to highlight some names
522+ :param shape_inference: run shape inference before printing the model
486523 :return: text
487524 """
488525 assert onx is not None , "onx cannot be None"
489526 if isinstance (onx , str ):
490527 onx = onnx_load (onx , load_external_data = False )
491528 assert onx is not None , "onx cannot be None"
492529
530+ if shape_inference :
531+ onx = onx .shape_inference .infer_shapes (onx )
532+
493533 if isinstance (onx , ValueInfoProto ):
494534 name = onx .name
495535 itype = onx .type .tensor_type .elem_type
@@ -577,7 +617,7 @@ def make_hash(obj: Any) -> str:
577617
578618def get_onnx_signature (model : ModelProto ) -> Tuple [Tuple [str , Any ], ...]:
579619 """
580- Produces a tuple of tuples correspinding to the signatures.
620+ Produces a tuple of tuples corresponding to the signatures.
581621
582622 :param model: model
583623 :return: signature
0 commit comments