@@ -1547,6 +1547,107 @@ def _cmd_compare(argv: List[Any]):
15471547 print (ObsComparePair .to_str (pair_cmp ))
15481548
15491549
def get_parser_optimize() -> ArgumentParser:
    """
    Builds the argument parser of command ``optimize``.

    The parser expects two positional arguments, the optimization
    ``algorithm`` (restricted to a fixed list of choices) and the ``input``
    model, followed by the optional arguments ``--output``, ``--verbose``,
    ``--infer-shapes``, ``--processor`` and ``--remove-shape-info``.

    :return: the configured parser
    """
    parser = ArgumentParser(
        prog="optimize",
        formatter_class=RawTextHelpFormatter,
        description=textwrap.dedent(
            """
            Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
            and replaces them by the corresponding nodes. It also does basic optimization
            such as removing identity nodes or unused nodes.
            """
        ),
        epilog=textwrap.dedent(
            """
            The goal is to make the model faster.
            Argument patterns defines the patterns to apply or the set of patterns.
            It is possible to show statistics or to remove a particular pattern.
            Here are some environment variables which can be used to trigger
            these displays.

            Available options algorithms, default and default+onnxruntime:

            - DROPPATTERN=<pattern1,patterns2,...>: do not apply
                those patterns when optimizing a model
            - DUMPPATTERNS=<folder>: dumps all matched and applied
                nodes when a pattern is applied
            - PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
                patterns to understand why one pattern was not applied,
                this shows which line is rejecting a pattern if it seems one pattern was missed
            """
        ),
    )
    parser.add_argument(
        "algorithm",
        choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
        help="algorithm or patterns optimization to apply",
    )
    parser.add_argument("input", type=str, help="onnx model to optimize")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=False,
        # The suffix documented here must match the default name built
        # by the command implementation (``.o-{algorithm}.onnx``).
        help="onnx model to output, if empty, it adds .o-{algorithm}.onnx to the name",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        required=False,
        type=int,
        help="verbosity",
    )
    parser.add_argument(
        "--infer-shapes",
        default=True,
        action=BooleanOptionalAction,
        help="infer shapes before optimizing the model",
    )
    parser.add_argument(
        "--processor",
        default="",
        help=textwrap.dedent(
            """
            optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
            some operators are only available in one processor, it might be not used
            with all
            """
        ).strip("\n "),
    )
    parser.add_argument(
        "--remove-shape-info",
        default=True,
        action=BooleanOptionalAction,
        help="remove shape information before outputting the model",
    )
    return parser
1626+
1627+
def _cmd_optimize(argv: List[Any]):
    """
    Runs command ``optimize``.

    The first element of *argv* (the command name) is skipped, the
    remaining ones are parsed with :func:`get_parser_optimize`.
    When no output is given, the output name is the input name without
    its extension followed by ``.o-<algorithm>.onnx``.

    :param argv: command line, command name included
    """
    args = get_parser_optimize().parse_args(argv[1:])

    # Imported only once the arguments are known to be valid.
    from .helpers.optim_helper import optimize_model

    stem = os.path.splitext(args.input)[0]
    destination = args.output or f"{stem}.o-{args.algorithm}.onnx"

    optimize_model(
        args.algorithm,
        args.input,
        output=destination,
        verbose=args.verbose,
        processor=args.processor,
        infer_shapes=args.infer_shapes,
        remove_shape_info=args.remove_shape_info,
    )
1649+
1650+
15501651#############
15511652# main parser
15521653#############
@@ -1563,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
15631664 to get help for a specific command.
15641665
15651666 agg - aggregates statistics from multiple files
1566- config - prints a configuration for a model id
1667+ config - prints a configuration for a model id (on HuggingFace Hub)
15671668 dot - converts an onnx model into dot format
15681669 exportsample - produces a code to export a model
15691670 find - find node consuming or producing a result
1570- lighten - makes an onnx model lighter by removing the weights,
1671+ lighten - makes an onnx model lighter by removing the weights
1672+ optimize - optimizes an onnx model
15711673 print - prints the model on standard output
15721674 sbs - compares an exported program and a onnx model
15731675 stats - produces statistics on a model
15741676 unlighten - restores an onnx model produces by the previous experiment
1575- validate - validate a model
1677+ validate - validate a model (knowing its model id on HuggginFace Hub)
15761678 """
15771679 ),
15781680 )
@@ -1585,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
15851687 "exportsample" ,
15861688 "find" ,
15871689 "lighten" ,
1690+ "optimize" ,
15881691 "print" ,
15891692 "sbs" ,
15901693 "stats" ,
@@ -1605,6 +1708,7 @@ def main(argv: Optional[List[Any]] = None):
16051708 exportsample = _cmd_export_sample ,
16061709 find = _cmd_find ,
16071710 lighten = _cmd_lighten ,
1711+ optimize = _cmd_optimize ,
16081712 print = _cmd_print ,
16091713 sbs = _cmd_sbs ,
16101714 stats = _cmd_stats ,
@@ -1631,6 +1735,7 @@ def main(argv: Optional[List[Any]] = None):
16311735 exportsample = lambda : get_parser_validate ("exportsample" ), # type: ignore[operator]
16321736 find = get_parser_find ,
16331737 lighten = get_parser_lighten ,
1738+ optimize = get_parser_optimize ,
16341739 print = get_parser_print ,
16351740 sbs = get_parser_sbs ,
16361741 stats = get_parser_stats ,
0 commit comments