11import contextlib
22from collections .abc import Iterable
3- from typing import Any , Optional , Tuple , Union
3+ from typing import Any , Callable , List , Optional , Tuple , Union
44import numpy as np
55import torch
66from .helper import string_type
1212)
1313
1414
15- def _forward_ (* args , _f = None , _context = None , ** kwargs ):
15+ def _forward_ (* args , _f = None , _fprint = string_type , _prefix = "" , _context = None , ** kwargs ):
1616 assert _f is not None , "_f cannot be None"
1717 assert _context is not None , "_context cannot be None"
18+ indent = " " * (len (_prefix ) - len (_prefix .lstrip ()))
19+ _prefix = _prefix .lstrip ()
1820 print (
19- f"-- -- stolen forward for class { _context ['class_name' ]} "
21+ f"{ indent } + { _prefix } -- stolen forward for class { _context ['class_name' ]} "
2022 f"-- iteration { _context ['iteration' ]} "
2123 )
2224 kws = dict (
@@ -25,36 +27,54 @@ def _forward_(*args, _f=None, _context=None, **kwargs):
2527 )
2628 if not hasattr (torch .compiler , "is_exporting" ) or not torch .compiler .is_exporting ():
2729 # torch.compiler.is_exporting requires torch>=2.7
28- print (f" <- args={ string_type (args , ** kws )} --- kwargs={ string_type (kwargs , ** kws )} " )
30+ print (f"{ indent } <- args={ _fprint (args , ** kws )} --- kwargs={ _fprint (kwargs , ** kws )} " )
2931 res = _f (* args , ** kwargs )
3032 if not hasattr (torch .compiler , "is_exporting" ) or not torch .compiler .is_exporting ():
31- print (" --" )
32- print (f" -> { string_type (res , ** kws )} " )
33- print ("." )
33+ print (f"{ indent } -> { _fprint (res , ** kws )} " )
34+ print (f"{ indent } -{ _prefix } ." )
3435 _context ["iteration" ] += 1
3536 return res
3637
3738
3839@contextlib .contextmanager
39- def steal_forward (model : torch .nn .Module , with_shape : bool = True , with_min_max : bool = False ):
40+ def steal_forward (
41+ model : Union [
42+ Union [torch .nn .Module , Tuple [str , torch .nn .Module ]],
43+ List [Union [torch .nn .Module , Tuple [str , torch .nn .Module ]]],
44+ ],
45+ fprint : Callable = string_type ,
46+ ** kwargs ,
47+ ):
4048 """
4149 The necessary modification to steem forward method and prints out inputs
4250 and outputs. See example :ref:`l-plot-tiny-llm-export`.
51+
52+ :param model: a model or a list of models to monitor,
53+ every model can also be a tuple(name, model), name is displayed well.
54+ :param fprint: function used to print out (or dump), by default, it is
55+ :func:`onnx_diagnostic.helpers.string_type`
56+ :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
57+ or any other function defined by ``fprint``
4358 """
44- context = dict (
45- iteration = 0 ,
46- class_name = model .__class__ .__name__ ,
47- with_shape = with_shape ,
48- with_min_max = with_min_max ,
49- )
50- keep_model_forward = model .forward
51- model .forward = lambda * args , _f = keep_model_forward , _context = context , ** kwargs : _forward_ (
52- * args , _f = _f , _context = _context , ** kwargs
53- )
59+ context = dict (iteration = 0 , ** kwargs )
60+ if "with_shape" not in context and fprint == string_type :
61+ context ["with_shape" ] = True
62+ if not isinstance (model , list ):
63+ model = [model ]
64+ keep_model_forward = {}
65+ for mt in model :
66+ name , m = mt if isinstance (mt , tuple ) else ("" , mt )
67+ keep_model_forward [id (m )] = (m , m .forward )
68+ c = context .copy ()
69+ c ["class_name" ] = m .__class__ .__name__
70+ m .forward = lambda * args , _f = m .forward , _fp = fprint , _c = c , _p = name , ** kws : _forward_ (
71+ * args , _f = _f , _fprint = _fp , _context = _c , _prefix = _p , ** kws
72+ )
5473 try :
5574 yield
5675 finally :
57- model .forward = keep_model_forward
76+ for f in keep_model_forward .values ():
77+ f [0 ].forward = f [1 ]
5878
5979
6080def is_torchdynamo_exporting () -> bool :
0 commit comments