@@ -44,6 +44,7 @@ class TorchOnnxEvaluator:
4444 :param providers: where to run the model
4545 :param opsets: needed if proto is a graph
4646 :param functions: known local functions
47+ :param verbose: verbosity level
4748
4849 The class holds the following attributes:
4950
@@ -56,6 +57,7 @@ class TorchOnnxEvaluator:
5657 * `last_used`: contains the list of intermediate results,
5758 to remove after every node execution,
5859 this prevents the memory from growing too much
60+ * `functions`: local functions
5961
6062 The class is not multithreaded. `runtime_info` gets updated
6163 by the class. The list of available kernels is returned by function
@@ -68,12 +70,15 @@ def __init__(
6870 providers : Tuple [str , ...] = ("CPUExecutionProvider" ,),
6971 opsets : Optional [Dict [str , int ]] = None ,
7072 local_functions : Optional [Dict [Tuple [str , str ], "TorchOnnxEvaluator" ]] = None ,
73+ verbose : int = 0 ,
7174 ):
75+ assert verbose
7276 self .providers = providers
7377 self .constants : Dict [str , torch .Tensor ] = {}
7478 self .kernels : List [Optional [torch_ops .OpRun ]] = []
7579 self .functions = local_functions .copy () if local_functions else {}
7680 self .CPU = torch .tensor ([0 ]).to ("cpu" ).device
81+ self .verbose = verbose
7782 if "CUDAExecutionProvider" in providers :
7883 self .CUDA = torch .tensor ([0 ]).to ("cuda" ).device
7984 self .default_device = self .CUDA
@@ -87,8 +92,11 @@ def __init__(
8792 assert not proto .graph .sparse_initializer , "sparse_initializer not support yet"
8893 self .opsets = {d .domain : d .version for d in proto .opset_import }
8994 for f in proto .functions :
90- self .functions [f .domain , f .name ] = TorchOnnxEvaluator (
91- f , providers = providers , local_functions = self .functions
95+ self .functions [f .domain , f .name ] = self .__class__ (
96+ f ,
97+ providers = providers ,
98+ local_functions = self .functions ,
99+ verbose = self .verbose ,
92100 )
93101 self ._build_initializers (proto .graph .initializer )
94102 self ._build_initializers (proto .graph .node )
@@ -206,22 +214,36 @@ def run(
206214 if not r .has_value :
207215 r .set_value (
208216 torch_ops .OpRunValue (
209- v .to (self .CUDA ) if r .is_shape and self .on_cuda else v , True
217+ v .to (self .CUDA ) if not r .is_shape and self .on_cuda else v ,
218+ is_constant = True ,
219+ may_cpu = len (v .shape ) == 1 and v .numel () < 8 and v .dtype == torch .int64 ,
210220 )
211221 )
222+ if self .verbose :
223+ print (f"+C { r .name } : { r .string_type ()} " )
212224
213225 # inputs
214226 for k , v in feeds .items ():
215227 r = self .runtime_info [k ]
216228 r .set_value (
217229 torch_ops .OpRunValue (
218- v .to (self .CUDA ) if r .is_shape and self .on_cuda else v , False
230+ v .to (self .CUDA ) if not r .is_shape and self .on_cuda else v ,
231+ is_constant = False ,
232+ may_cpu = len (v .shape ) == 1 and v .numel () < 8 and v .dtype == torch .int64 ,
219233 )
220234 )
235+ if self .verbose :
236+ print (f"+I { r .name } : { r .string_type ()} " )
221237
222238 # node execution
223239 for it , kernel in enumerate (self .kernels ):
224240 if kernel is not None :
241+ if self .verbose :
242+ print (
243+ f"{ kernel .__class__ .__name__ } "
244+ f"({ ', ' .join (kernel .input )} ) -> "
245+ f"{ ', ' .join (kernel .output )} "
246+ )
225247 # kernel execution
226248 inputs = [(self .runtime_info [i ].value if i else None ) for i in kernel .input ]
227249 if kernel .has_subgraphs ():
@@ -236,26 +258,42 @@ def run(
236258 )
237259 for name , t in zip (kernel .output , res ):
238260 self .runtime_info [name ].set_value (t )
261+ if self .verbose :
262+ for name in kernel .output :
263+ print (f"+R { name } : { self .runtime_info [name ].string_type ()} " )
239264 else :
240265 assert isinstance (
241266 res , torch_ops .OpRunValue
242267 ), f"Unexpected output type { type (res )} for kernel { type (kernel )} ."
243268 self .runtime_info [kernel .output [0 ]].set_value (res )
269+ if self .verbose :
270+ print (
271+ f"+R { kernel .output [0 ]} : "
272+ f"{ self .runtime_info [kernel .output [0 ]].string_type ()} "
273+ )
244274
245275 # free intermediate results
246276 for name in self .last_used [it ]:
247277 self .runtime_info [name ].clean_value ()
278+ if self .verbose :
279+ print (f"- clean { name } " )
248280
249281 assert all (
250282 self .runtime_info [o ].value is not None for o in outputs
251283 ), "Not implemented yet when one output is None."
252284 fres = [self .runtime_info [o ].value .tensor for o in outputs ] # type: ignore[union-attr]
285+ if self .verbose :
286+ print (f"++ outputs { ', ' .join (outputs )} " )
253287
254288 # clean previous execution
255289 for k in feeds :
256290 self .runtime_info [k ].clean_value ()
291+ if self .verbose :
292+ print (f"- clean { k } " )
257293 for o in outputs :
258294 self .runtime_info [o ].clean_value ()
295+ if self .verbose :
296+ print (f"- clean { o } " )
259297
260298 if use_numpy :
261299 return [None if a is None else a .detach ().cpu ().numpy () for a in fres ]
@@ -285,7 +323,9 @@ def run_with_values(
285323 if not r .has_value :
286324 r .set_value (
287325 torch_ops .OpRunValue (
288- v .to (self .CUDA ) if r .is_shape and self .on_cuda else v , True
326+ v .to (self .CUDA ) if r .is_shape is False and self .on_cuda else v ,
327+ is_constant = True ,
328+ may_cpu = len (v .shape ) == 1 and v .numel () < 8 and v .dtype == torch .int64 ,
289329 )
290330 )
291331
0 commit comments