Skip to content

Commit 69acc8d

Browse files
authored
First PR (#1)
* changes * clean * fix doc * doc * clean * fix issues * fix issues * mypy * doc
1 parent e6ac972 commit 69acc8d

File tree

12 files changed

+164
-28
lines changed

12 files changed

+164
-28
lines changed

.github/workflows/documentation.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ jobs:
8282
run: cat doc.txt
8383

8484
- name: Check for errors and warnings
85-
continue-on-error: true
8685
run: |
8786
if [[ $(grep ERROR doc.txt | grep -v 'Unknown target name: "l_shape"' | grep -v 'Unknown target name: "l_x"') ]]; then
8887
echo "Documentation produces errors."

.github/workflows/mypy.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name: Type annotation with mypy
2+
on: [push, pull_request]
3+
jobs:
4+
mypy:
5+
runs-on: ubuntu-latest
6+
steps:
7+
- uses: actions/checkout@v3
8+
- uses: actions/setup-python@v4
9+
with:
10+
python-version: '3.12'
11+
- name: Install mypy
12+
run: pip install mypy
13+
- name: Run mypy
14+
run: mypy

README.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,51 @@ or
4444

4545
pip install onnx-diagnostic
4646

47+
Snapshot of usefuls tools
48+
+++++++++++++++++++++++++
49+
50+
**string_type**
51+
52+
.. code-block:: python
53+
54+
import torch
55+
from onnx_diagnostic.helpers import string_type
56+
57+
inputs = (
58+
torch.rand((3, 4), dtype=torch.float16),
59+
[
60+
torch.rand((5, 6), dtype=torch.float16),
61+
torch.rand((5, 6, 7), dtype=torch.float16),
62+
]
63+
)
64+
65+
# with shapes
66+
print(string_type(inputs, with_shape=True))
67+
68+
::
69+
70+
>>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
71+
72+
**onnx_dtype_name**
73+
74+
.. code-block:: python
75+
76+
import onnx
77+
from onnx_diagnostic.helpers import onnx_dtype_name
78+
79+
itype = onnx.TensorProto.BFLOAT16
80+
print(onnx_dtype_name(itype))
81+
print(onnx_dtype_name(7))
82+
83+
::
84+
85+
>>> BFLOAT16
86+
>>> INT64
87+
88+
**max_diff**
89+
90+
Returns the maximum discrancies accross nested containers containing tensors.
91+
4792
Documentation
4893
+++++++++++++
4994

_doc/_static/logo.png

2.11 KB
Loading

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ API of onnx_diagnostic
1616
cache_helpers
1717
ext_test_case
1818
helpers
19+
onnx_tools
1920
ort_session
2021
torch_test_helper
2122

_doc/conf.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@
8585
"matplotlib": ("https://matplotlib.org/stable/", None),
8686
"numpy": ("https://numpy.org/doc/stable", None),
8787
"onnx": ("https://onnx.ai/onnx/", None),
88-
"onnx_diagnostic": (
89-
"https://sdpython.github.io/doc/onnx-diagnostic/dev/",
90-
None,
91-
),
9288
"onnx_array_api": ("https://sdpython.github.io/doc/onnx-array-api/dev/", None),
9389
"onnx_extended": ("https://sdpython.github.io/doc/onnx-extended/dev/", None),
9490
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
@@ -118,8 +114,8 @@
118114
("py:class", "torch.utils._pytree.KeyEntry"),
119115
("py:class", "torch.utils._pytree.TreeSpec"),
120116
("py:class", "transformers.cache_utils.Cache"),
121-
# ("py:class", "transformers.cache_utils.DynamicCache"),
122-
# ("py:class", "transformers.cache_utils.MambaCache"),
117+
("py:class", "transformers.cache_utils.DynamicCache"),
118+
("py:class", "transformers.cache_utils.MambaCache"),
123119
("py:func", "torch.export._draft_export.draft_export"),
124120
("py:func", "torch._export.tools.report_exportability"),
125121
]
@@ -128,7 +124,9 @@
128124
("py:func", ".*numpy[.].*"),
129125
("py:func", ".*scipy[.].*"),
130126
# ("py:func", ".*torch.ops.higher_order.*"),
127+
("py:class", ".*numpy._typing[.].*"),
131128
("py:class", ".*onnxruntime[.].*"),
129+
("py:meth", ".*onnxruntime[.].*"),
132130
]
133131

134132

@@ -148,7 +146,7 @@
148146
# errors
149147
"abort_on_example_error": True,
150148
# recommendation
151-
"recommender": {"enable": True, "n_examples": 5, "min_df": 3, "max_df": 0.9},
149+
"recommender": {"enable": True, "n_examples": 3, "min_df": 3, "max_df": 0.9},
152150
# ignore capture for matplotib axes
153151
"ignore_repr_types": "matplotlib\\.(text|axes)",
154152
# robubstness

_doc/examples/plot_exporter_exporter_dynamic_shapes_auto.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
"""
2-
.. _l-plot-exporter-dynamic_shapes:
3-
42
Use DYNAMIC or AUTO when dynamic shapes has constraints
53
=======================================================
64

_doc/index.rst

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx-diagnostic: fuzzy work
3-
===================================
2+
onnx-diagnostic: investigate onnx models
3+
========================================
44

55
.. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg
66
:target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml
@@ -36,6 +36,7 @@ Source are `sdpython/onnx-diagnostic
3636
:caption: Contents
3737

3838
api/index
39+
galleries
3940

4041
.. toctree::
4142
:maxdepth: 1
@@ -44,6 +45,45 @@ Source are `sdpython/onnx-diagnostic
4445
CHANGELOGS
4546
license
4647

48+
49+
**Some usefuls tools**
50+
51+
.. code-block:: python
52+
53+
import torch
54+
from onnx_diagnostic.helpers import string_type
55+
56+
inputs = (
57+
torch.rand((3, 4), dtype=torch.float16),
58+
[
59+
torch.rand((5, 6), dtype=torch.float16),
60+
torch.rand((5, 6, 7), dtype=torch.float16),
61+
]
62+
)
63+
64+
# with shapes
65+
print(string_type(inputs, with_shape=True))
66+
67+
::
68+
69+
>>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
70+
71+
.. code-block:: python
72+
73+
import onnx
74+
from onnx_diagnostic.helpers import onnx_dtype_name
75+
76+
itype = onnx.TensorProto.BFLOAT16
77+
print(onnx_dtype_name(itype))
78+
print(onnx_dtype_name(7))
79+
80+
::
81+
82+
>>> BFLOAT16
83+
>>> INT64
84+
85+
:func:`onnx_diagnostic.helpers.max_diff`, ...
86+
4787
The documentation was updated on:
4888

4989
.. runpython::

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import glob
7-
import importlib
7+
import importlib.util
88
import logging
99
import os
1010
import re

onnx_diagnostic/helpers.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import sys
66
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
77
import numpy as np
8+
import numpy.typing as npt
89
from onnx import (
910
AttributeProto,
10-
DataType,
1111
FunctionProto,
1212
GraphProto,
1313
ModelProto,
@@ -87,7 +87,7 @@ def size_type(dtype: Any) -> int:
8787
raise AssertionError(f"Unexpected dtype={dtype}")
8888

8989

90-
def tensor_dtype_to_np_dtype(tensor_dtype: DataType) -> np.dtype:
90+
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
9191
"""
9292
Converts a TensorProto's data_type to corresponding numpy dtype.
9393
It can be used while making tensor.
@@ -105,7 +105,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: DataType) -> np.dtype:
105105
f"ml_dtypes can be used."
106106
) from e
107107

108-
mapping = {
108+
mapping: Dict[int, np.dtype] = {
109109
TensorProto.BFLOAT16: ml_dtypes.bfloat16,
110110
TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
111111
TensorProto.FLOAT8E4M3FNUZ: ml_dtypes.float8_e4m3fnuz,
@@ -142,7 +142,30 @@ def string_type(
142142
:showcode:
143143
144144
from onnx_diagnostic.helpers import string_type
145+
145146
print(string_type((1, ["r", 6.6])))
147+
148+
With pytorch:
149+
150+
.. runpython::
151+
:showcode:
152+
153+
import torch
154+
from onnx_diagnostic.helpers import string_type
155+
156+
inputs = (
157+
torch.rand((3, 4), dtype=torch.float16),
158+
[
159+
torch.rand((5, 6), dtype=torch.float16),
160+
torch.rand((5, 6, 7), dtype=torch.float16),
161+
]
162+
)
163+
164+
# with shapes
165+
print(string_type(inputs, with_shape=True))
166+
167+
# with min max
168+
print(string_type(inputs, with_shape=True, with_min_max=True))
146169
"""
147170
if obj is None:
148171
return "None"
@@ -465,7 +488,19 @@ def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
465488

466489
@functools.cache
467490
def onnx_dtype_name(itype: int) -> str:
468-
"""Returns the ONNX name for a specific element type."""
491+
"""
492+
Returns the ONNX name for a specific element type.
493+
494+
.. runpython::
495+
:showcode:
496+
497+
import onnx
498+
from onnx_diagnostic.helpers import onnx_dtype_name
499+
500+
itype = onnx.TensorProto.BFLOAT16
501+
print(onnx_dtype_name(itype))
502+
print(onnx_dtype_name(7))
503+
"""
469504
for k in dir(TensorProto):
470505
v = getattr(TensorProto, k)
471506
if v == itype:
@@ -477,19 +512,24 @@ def pretty_onnx(
477512
onx: Union[FunctionProto, GraphProto, ModelProto, ValueInfoProto, str],
478513
with_attributes: bool = False,
479514
highlight: Optional[Set[str]] = None,
515+
shape_inference: bool = False,
480516
) -> str:
481517
"""
482518
Displays an onnx prot in a better way.
483519
484520
:param with_attributes: displays attributes as well, if only a node is printed
485521
:param highlight: to highlight some names
522+
:param shape_inference: run shape inference before printing the model
486523
:return: text
487524
"""
488525
assert onx is not None, "onx cannot be None"
489526
if isinstance(onx, str):
490527
onx = onnx_load(onx, load_external_data=False)
491528
assert onx is not None, "onx cannot be None"
492529

530+
if shape_inference:
531+
onx = onx.shape_inference.infer_shapes(onx)
532+
493533
if isinstance(onx, ValueInfoProto):
494534
name = onx.name
495535
itype = onx.type.tensor_type.elem_type
@@ -577,7 +617,7 @@ def make_hash(obj: Any) -> str:
577617

578618
def get_onnx_signature(model: ModelProto) -> Tuple[Tuple[str, Any], ...]:
579619
"""
580-
Produces a tuple of tuples correspinding to the signatures.
620+
Produces a tuple of tuples corresponding to the signatures.
581621
582622
:param model: model
583623
:return: signature
@@ -611,7 +651,7 @@ def convert_endian(tensor: TensorProto) -> None:
611651
tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
612652

613653

614-
def from_array_ml_dtypes(arr: np.ndarray, name: Optional[str] = None) -> TensorProto:
654+
def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
615655
"""
616656
Converts a numpy array to a tensor def assuming the dtype
617657
is defined in ml_dtypes.
@@ -625,7 +665,7 @@ def from_array_ml_dtypes(arr: np.ndarray, name: Optional[str] = None) -> TensorP
625665
"""
626666
import ml_dtypes
627667

628-
assert isinstance(arr, np.ndarray), f"arr must be of type np.ndarray, got {type(arr)}"
668+
assert isinstance(arr, np.ndarray), f"arr must be of type numpy.ndarray, got {type(arr)}"
629669

630670
tensor = TensorProto()
631671
tensor.dims.extend(arr.shape)
@@ -651,9 +691,9 @@ def from_array_ml_dtypes(arr: np.ndarray, name: Optional[str] = None) -> TensorP
651691
return tensor
652692

653693

654-
def from_array_extended(tensor: np.ndarray, name: Optional[str] = None) -> TensorProto:
694+
def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
655695
"""
656-
Converts an array into a TensorProto.
696+
Converts an array into a :class:`onnx.TensorProto`.
657697
658698
:param tensor: numpy array
659699
:param name: name

0 commit comments

Comments
 (0)