55import sys
66from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
77import numpy as np
8+ import numpy .typing as npt
89from onnx import (
910 AttributeProto ,
10- DataType ,
1111 FunctionProto ,
1212 GraphProto ,
1313 ModelProto ,
@@ -87,7 +87,7 @@ def size_type(dtype: Any) -> int:
8787 raise AssertionError (f"Unexpected dtype={ dtype } " )
8888
8989
90- def tensor_dtype_to_np_dtype (tensor_dtype : DataType ) -> np .dtype :
90+ def tensor_dtype_to_np_dtype (tensor_dtype : int ) -> np .dtype :
9191 """
9292 Converts a TensorProto's data_type to corresponding numpy dtype.
9393 It can be used while making tensor.
@@ -105,7 +105,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: DataType) -> np.dtype:
105105 f"ml_dtypes can be used."
106106 ) from e
107107
108- mapping = {
108+ mapping : Dict [ int , np . dtype ] = {
109109 TensorProto .BFLOAT16 : ml_dtypes .bfloat16 ,
110110 TensorProto .FLOAT8E4M3FN : ml_dtypes .float8_e4m3fn ,
111111 TensorProto .FLOAT8E4M3FNUZ : ml_dtypes .float8_e4m3fnuz ,
@@ -142,7 +142,30 @@ def string_type(
142142 :showcode:
143143
144144 from onnx_diagnostic.helpers import string_type
145+
145146 print(string_type((1, ["r", 6.6])))
147+
148+ With pytorch:
149+
150+ .. runpython::
151+ :showcode:
152+
153+ import torch
154+ from onnx_diagnostic.helpers import string_type
155+
156+ inputs = (
157+ torch.rand((3, 4), dtype=torch.float16),
158+ [
159+ torch.rand((5, 6), dtype=torch.float16),
160+ torch.rand((5, 6, 7), dtype=torch.float16),
161+ ]
162+ )
163+
164+ # with shapes
165+ print(string_type(inputs, with_shape=True))
166+
167+ # with min max
168+ print(string_type(inputs, with_shape=True, with_min_max=True))
146169 """
147170 if obj is None :
148171 return "None"
@@ -465,7 +488,19 @@ def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
465488
466489@functools .cache
467490def onnx_dtype_name (itype : int ) -> str :
468- """Returns the ONNX name for a specific element type."""
491+ """
492+ Returns the ONNX name for a specific element type.
493+
494+ .. runpython::
495+ :showcode:
496+
497+ import onnx
498+ from onnx_diagnostic.helpers import onnx_dtype_name
499+
500+ itype = onnx.TensorProto.BFLOAT16
501+ print(onnx_dtype_name(itype))
502+ print(onnx_dtype_name(7))
503+ """
469504 for k in dir (TensorProto ):
470505 v = getattr (TensorProto , k )
471506 if v == itype :
@@ -477,19 +512,24 @@ def pretty_onnx(
477512 onx : Union [FunctionProto , GraphProto , ModelProto , ValueInfoProto , str ],
478513 with_attributes : bool = False ,
479514 highlight : Optional [Set [str ]] = None ,
515+ shape_inference : bool = False ,
480516) -> str :
481517 """
482518 Displays an onnx prot in a better way.
483519
484520 :param with_attributes: displays attributes as well, if only a node is printed
485521 :param highlight: to highlight some names
522+ :param shape_inference: run shape inference before printing the model
486523 :return: text
487524 """
488525 assert onx is not None , "onx cannot be None"
489526 if isinstance (onx , str ):
490527 onx = onnx_load (onx , load_external_data = False )
491528 assert onx is not None , "onx cannot be None"
492529
530+ if shape_inference :
531+ onx = onx .shape_inference .infer_shapes (onx )
532+
493533 if isinstance (onx , ValueInfoProto ):
494534 name = onx .name
495535 itype = onx .type .tensor_type .elem_type
@@ -577,7 +617,7 @@ def make_hash(obj: Any) -> str:
577617
578618def get_onnx_signature (model : ModelProto ) -> Tuple [Tuple [str , Any ], ...]:
579619 """
580- Produces a tuple of tuples correspinding to the signatures.
620+ Produces a tuple of tuples corresponding to the signatures.
581621
582622 :param model: model
583623 :return: signature
@@ -611,7 +651,7 @@ def convert_endian(tensor: TensorProto) -> None:
611651 tensor .raw_data = np .frombuffer (tensor .raw_data , dtype = np_dtype ).byteswap ().tobytes ()
612652
613653
614- def from_array_ml_dtypes (arr : np . ndarray , name : Optional [str ] = None ) -> TensorProto :
654+ def from_array_ml_dtypes (arr : npt . ArrayLike , name : Optional [str ] = None ) -> TensorProto :
615655 """
616656 Converts a numpy array to a tensor def assuming the dtype
617657 is defined in ml_dtypes.
@@ -625,7 +665,7 @@ def from_array_ml_dtypes(arr: np.ndarray, name: Optional[str] = None) -> TensorP
625665 """
626666 import ml_dtypes
627667
628- assert isinstance (arr , np .ndarray ), f"arr must be of type np .ndarray, got { type (arr )} "
668+ assert isinstance (arr , np .ndarray ), f"arr must be of type numpy .ndarray, got { type (arr )} "
629669
630670 tensor = TensorProto ()
631671 tensor .dims .extend (arr .shape )
@@ -651,9 +691,9 @@ def from_array_ml_dtypes(arr: np.ndarray, name: Optional[str] = None) -> TensorP
651691 return tensor
652692
653693
654- def from_array_extended (tensor : np . ndarray , name : Optional [str ] = None ) -> TensorProto :
694+ def from_array_extended (tensor : npt . ArrayLike , name : Optional [str ] = None ) -> TensorProto :
655695 """
656- Converts an array into a TensorProto.
696+ Converts an array into a :class:`onnx. TensorProto` .
657697
658698 :param tensor: numpy array
659699 :param name: name
0 commit comments