Skip to content

Commit 90c9064

Browse files
committed
fix issues
1 parent fb81896 commit 90c9064

File tree

5 files changed

+177
-36
lines changed

5 files changed

+177
-36
lines changed

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,10 @@ def test_iterate_function(self):
246246
def test_statistics(self):
    """Checks ``tensor_statistics`` reports the ONNX type name in ``stype``."""
    # (numpy dtype, name expected in stat["stype"])
    expectations = [(np.float16, "FLOAT16"), (np.float32, "FLOAT")]
    for dtype, expected_name in expectations:
        values = np.random.rand(40, 50).astype(dtype)
        report = tensor_statistics(values)
        self.assertEqual(report["stype"], expected_name)
250253

251254

252255
if __name__ == "__main__":

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
get_parser_find,
99
get_parser_lighten,
1010
get_parser_print,
11+
get_parser_stats,
1112
get_parser_unlighten,
1213
get_parser_validate,
1314
)
@@ -63,6 +64,13 @@ def test_parser_validate(self):
6364
text = st.getvalue()
6465
self.assertIn("mid", text)
6566

67+
def test_parser_stats(self):
    """The help message of command ``stats`` mentions option ``input``."""
    buffer = StringIO()
    with redirect_stdout(buffer):
        stats_parser = get_parser_stats()
        stats_parser.print_help()
    # print_help writes to stdout, captured in the buffer above
    self.assertIn("input", buffer.getvalue())
73+
6674

6775
if __name__ == "__main__":
6876
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ def test_parser_print(self):
2020
text = st.getvalue()
2121
self.assertIn("Add", text)
2222

23+
def test_parser_stats(self):
    """Runs command ``stats`` end to end and checks the output file exists."""
    xlsx_path = self.get_dump_file("test_parser_stats.xlsx")
    command_line = ["stats", "-i", self.dummy_path, "-o", xlsx_path]
    captured = StringIO()
    with redirect_stdout(captured):
        main(command_line)
    # default verbosity prints one "processing ..." line per tensor
    self.assertIn("processing", captured.getvalue())
    self.assertExists(xlsx_path)
31+
2332
def test_parser_find(self):
2433
st = StringIO()
2534
with redirect_stdout(st):

onnx_diagnostic/_command_lines_parser.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
22
import json
3+
import os
4+
import re
35
import sys
46
import textwrap
57
import onnx
@@ -425,6 +427,106 @@ def _cmd_validate(argv: List[Any]):
425427
print(f":{k},{v};")
426428

427429

430+
def get_parser_stats() -> ArgumentParser:
    """
    Builds the argument parser of command ``stats``.

    Declared options:

    * ``-i/--input``: ONNX file (required)
    * ``-o/--output``: file receiving the statistics, default ``""``
    * ``-v/--verbose``: verbosity as an integer, default ``1``
    * ``-e/--end``: integer, default ``-1``
    * ``-b/--begin``: integer, default ``0``
    * ``-r/--regex``: regular expression on tensor names, default ``""``

    :return: the configured parser
    """
    description = dedent(
        """
        Prints out statistics on an ONNX model.
        """
    )
    stats_parser = ArgumentParser(prog="stats", description=description, epilog="")

    # One entry per option: (flags, keyword arguments), in help order.
    option_table = [
        (("-i", "--input"), dict(type=str, required=True, help="ONNX file")),
        (
            ("-o", "--output"),
            dict(required=False, default="", help="outputs the statistics in a file"),
        ),
        (
            ("-v", "--verbose"),
            dict(required=False, default=1, type=int, help="verbosity"),
        ),
        (
            ("-e", "--end"),
            dict(
                required=False,
                default=-1,
                type=int,
                help="ends after this many tensors",
            ),
        ),
        (
            ("-b", "--begin"),
            dict(
                required=False,
                default=0,
                type=int,
                help="starts after this many tensors",
            ),
        ),
        (
            ("-r", "--regex"),
            dict(
                required=False,
                default="",
                type=str,
                help="keeps only tensors whose name verifies "
                "this regular expression, empty = no filter",
            ),
        ),
    ]
    for flags, settings in option_table:
        stats_parser.add_argument(*flags, **settings)
    return stats_parser
488+
489+
490+
def _cmd_stats(argv: List[Any]):
    """
    Implements command ``stats``: computes statistics on the tensors
    yielded by ``iterator_initializer_constant`` for an ONNX model.

    :param argv: command line, ``argv[0]`` (the command name) is skipped,
        the remaining items are parsed by :func:`get_parser_stats`

    Statistics are printed on standard output when no output file is given.
    Otherwise they are saved with pandas, as an Excel file if the output
    extension is ``.xlsx``, as a CSV file in any other case.
    An ``AssertionError`` is raised if the input file does not exist.
    """
    from .helpers.onnx_helper import iterator_initializer_constant, tensor_statistics

    parser = get_parser_stats()
    args = parser.parse_args(argv[1:])
    assert os.path.exists(args.input), f"Missing filename {args.input!r}"
    if args.verbose:
        print(f"Loading {args.input}")
    onx = onnx.load(args.input)
    # An empty regular expression means no filtering on names.
    reg = re.compile(args.regex) if args.regex else None
    data = []
    # ``index`` counts every tensor, including those rejected by the regular
    # expression, so --begin/--end refer to positions in the unfiltered sequence.
    for index, (name, init) in enumerate(iterator_initializer_constant(onx)):
        # Fixed: the method is ``search``; the previous spelling ``seach``
        # raised AttributeError as soon as --regex was used.
        if reg and not reg.search(name):
            continue
        if index < args.begin:
            continue
        # A non-positive --end disables the upper bound.
        if args.end > 0 and index >= args.end:
            break
        if args.verbose:
            print(f"processing {index + 1}: {name!r}")
        stats = tensor_statistics(init)
        if not args.output:
            print(f"{name}: {stats}")
        stats["name"] = name
        data.append(stats)
    if args.output:
        if args.verbose:
            print(f"saving into {args.output!r}")
        # pandas is only needed when the statistics are saved into a file
        import pandas

        df = pandas.DataFrame(data)
        ext = os.path.splitext(args.output)
        if ext[-1] == ".xlsx":
            df.to_excel(args.output, index=False)
        else:
            df.to_csv(args.output, index=False)
    if args.verbose:
        print("done.")
528+
529+
428530
def get_main_parser() -> ArgumentParser:
429531
parser = ArgumentParser(
430532
prog="onnx_diagnostic",
@@ -441,12 +543,13 @@ def get_main_parser() -> ArgumentParser:
441543
unlighten - restores an onnx model produces by the previous experiment
442544
print - prints the model on standard output
443545
validate - validate a model
546+
stats - produces statistics on a model
444547
"""
445548
),
446549
)
447550
parser.add_argument(
448551
"cmd",
449-
choices=["config", "find", "lighten", "print", "unlighten", "validate"],
552+
choices=["config", "find", "lighten", "print", "stats", "unlighten", "validate"],
450553
help="Selects a command.",
451554
)
452555
return parser
@@ -460,6 +563,7 @@ def main(argv: Optional[List[Any]] = None):
460563
find=_cmd_find,
461564
config=_cmd_config,
462565
validate=_cmd_validate,
566+
stats=_cmd_stats,
463567
)
464568

465569
if argv is None:
@@ -480,6 +584,7 @@ def main(argv: Optional[List[Any]] = None):
480584
find=get_parser_find,
481585
config=get_parser_config,
482586
validate=get_parser_validate,
587+
stats=get_parser_stats,
483588
)
484589
cmd = argv[0]
485590
if cmd not in parsers:

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import sys
5+
import warnings
56
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
67
import numpy as np
78
import numpy.typing as npt
@@ -330,9 +331,10 @@ def onnx_dtype_name(itype: int) -> str:
330331
print(onnx_dtype_name(7))
331332
"""
332333
for k in dir(TensorProto):
333-
v = getattr(TensorProto, k)
334-
if v == itype:
335-
return k
334+
if "FLOAT" in k or "INT" in k or "TEXT" in k or "BOOL" in k:
335+
v = getattr(TensorProto, k)
336+
if v == itype:
337+
return k
336338
raise ValueError(f"Unexpected value itype: {itype}")
337339

338340

@@ -841,47 +843,61 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union
841843

842844
if isinstance(tensor, TensorProto):
843845
tensor = to_array_extended(tensor)
846+
itype = np_dtype_to_tensor_dtype(tensor.dtype)
844847
stat = dict(
845848
mean=float(tensor.mean()),
846849
std=float(tensor.std()),
847850
shape="x".join(map(str, tensor.shape)),
848851
numel=tensor.size,
849852
size=tensor.size * size_type(tensor.dtype),
850-
itype=np_dtype_to_tensor_dtype(tensor.dtype),
851-
stype=onnx_dtype_name(np_dtype_to_tensor_dtype(tensor.dtype)),
853+
itype=itype,
854+
stype=onnx_dtype_name(itype),
852855
min=float(tensor.min()),
853856
max=float(tensor.max()),
854-
nnan=np.isnan(tensor).sum(),
857+
nnan=float(np.isnan(tensor).sum()),
855858
)
856859

857-
hist = np.array(
858-
[
859-
0,
860-
1e-10,
861-
1e-8,
862-
1e-7,
863-
1e-6,
864-
1e-5,
865-
0.0001,
866-
0.001,
867-
0.01,
868-
0.1,
869-
0.5,
870-
1,
871-
1.96,
872-
10,
873-
100,
874-
1e3,
875-
1e4,
876-
1e5,
877-
1e6,
878-
1e7,
879-
1e8,
880-
1e10,
881-
1e50,
882-
],
883-
dtype=tensor.dtype,
884-
)
860+
if tensor.size < 8:
861+
return stat
862+
863+
with warnings.catch_warnings():
864+
warnings.simplefilter("ignore")
865+
try:
866+
hist = np.array(
867+
[
868+
0,
869+
1e-10,
870+
1e-8,
871+
1e-7,
872+
1e-6,
873+
1e-5,
874+
0.0001,
875+
0.001,
876+
0.01,
877+
0.1,
878+
0.5,
879+
1,
880+
1.96,
881+
10,
882+
1e2,
883+
1e3,
884+
1e4,
885+
1e5,
886+
1e6,
887+
1e7,
888+
1e8,
889+
1e10,
890+
1e50,
891+
],
892+
dtype=tensor.dtype,
893+
)
894+
except OverflowError as e:
895+
from .helper import string_type
896+
897+
raise ValueError(
898+
f"Unable to convert one value into {tensor.dtype}, "
899+
f"tensor={string_type(tensor, with_shape=True)}"
900+
) from e
885901
hist = np.array(sorted(set(hist[~np.isinf(hist)])), dtype=tensor.dtype)
886902
ind = np.digitize(np.abs(tensor).reshape((-1,)), hist, right=True)
887903
cou = np.bincount(ind, minlength=ind.shape[0] + 1)

0 commit comments

Comments
 (0)