11import argparse
22import json
3+ import os
4+ import re
35import sys
46import textwrap
57import onnx
@@ -425,6 +427,106 @@ def _cmd_validate(argv: List[Any]):
425427 print (f":{ k } ,{ v } ;" )
426428
427429
430+ def get_parser_stats () -> ArgumentParser :
431+ parser = ArgumentParser (
432+ prog = "stats" ,
433+ description = dedent (
434+ """
435+ Prints out statistics on an ONNX model.
436+ """
437+ ),
438+ epilog = "" ,
439+ )
440+ parser .add_argument (
441+ "-i" ,
442+ "--input" ,
443+ type = str ,
444+ required = True ,
445+ help = "ONNX file" ,
446+ )
447+ parser .add_argument (
448+ "-o" ,
449+ "--output" ,
450+ required = False ,
451+ default = "" ,
452+ help = "outputs the statistics in a file" ,
453+ )
454+ parser .add_argument (
455+ "-v" ,
456+ "--verbose" ,
457+ required = False ,
458+ default = 1 ,
459+ type = int ,
460+ help = "verbosity" ,
461+ )
462+ parser .add_argument (
463+ "-e" ,
464+ "--end" ,
465+ required = False ,
466+ default = - 1 ,
467+ type = int ,
468+ help = "ends after this many tensors" ,
469+ )
470+ parser .add_argument (
471+ "-b" ,
472+ "--begin" ,
473+ required = False ,
474+ default = 0 ,
475+ type = int ,
476+ help = "starts after this many tensors" ,
477+ )
478+ parser .add_argument (
479+ "-r" ,
480+ "--regex" ,
481+ required = False ,
482+ default = "" ,
483+ type = str ,
484+ help = "keeps only tensors whose name verifies "
485+ "this regular expression, empty = no filter" ,
486+ )
487+ return parser
488+
489+
490+ def _cmd_stats (argv : List [Any ]):
491+ from .helpers .onnx_helper import iterator_initializer_constant , tensor_statistics
492+
493+ parser = get_parser_stats ()
494+ args = parser .parse_args (argv [1 :])
495+ assert os .path .exists (args .input ), f"Missing filename { args .input !r} "
496+ if args .verbose :
497+ print (f"Loading { args .input } " )
498+ onx = onnx .load (args .input )
499+ reg = re .compile (args .regex ) if args .regex else None
500+ data = []
501+ for index , (name , init ) in enumerate (iterator_initializer_constant (onx )):
502+ if reg and not reg .seach (name ):
503+ continue
504+ if index < args .begin :
505+ continue
506+ if args .end > 0 and index >= args .end :
507+ break
508+ if args .verbose :
509+ print (f"processing { index + 1 } : { name !r} " )
510+ stats = tensor_statistics (init )
511+ if not args .output :
512+ print (f"{ name } : { stats } " )
513+ stats ["name" ] = name
514+ data .append (stats )
515+ if args .output :
516+ if args .verbose :
517+ print (f"saving into { args .output !r} " )
518+ import pandas
519+
520+ df = pandas .DataFrame (data )
521+ ext = os .path .splitext (args .output )
522+ if ext [- 1 ] == ".xlsx" :
523+ df .to_excel (args .output , index = False )
524+ else :
525+ df .to_csv (args .output , index = False )
526+ if args .verbose :
527+ print ("done." )
528+
529+
428530def get_main_parser () -> ArgumentParser :
429531 parser = ArgumentParser (
430532 prog = "onnx_diagnostic" ,
@@ -441,12 +543,13 @@ def get_main_parser() -> ArgumentParser:
441543 unlighten - restores an onnx model produces by the previous experiment
442544 print - prints the model on standard output
443545 validate - validate a model
546+ stats - produces statistics on a model
444547 """
445548 ),
446549 )
447550 parser .add_argument (
448551 "cmd" ,
449- choices = ["config" , "find" , "lighten" , "print" , "unlighten" , "validate" ],
552+ choices = ["config" , "find" , "lighten" , "print" , "stats" , " unlighten" , "validate" ],
450553 help = "Selects a command." ,
451554 )
452555 return parser
@@ -460,6 +563,7 @@ def main(argv: Optional[List[Any]] = None):
460563 find = _cmd_find ,
461564 config = _cmd_config ,
462565 validate = _cmd_validate ,
566+ stats = _cmd_stats ,
463567 )
464568
465569 if argv is None :
@@ -480,6 +584,7 @@ def main(argv: Optional[List[Any]] = None):
480584 find = get_parser_find ,
481585 config = get_parser_config ,
482586 validate = get_parser_validate ,
587+ stats = get_parser_stats ,
483588 )
484589 cmd = argv [0 ]
485590 if cmd not in parsers :
0 commit comments