@@ -43,6 +43,7 @@ class TorchOnnxEvaluator:
     :param proto: a proto
     :param providers: where to run the model
     :param opsets: needed if proto is a graph
+    :param local_functions: known local functions

     The class holds the following attributes:

@@ -66,10 +67,12 @@ def __init__(
         proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
         providers: Tuple[str, ...] = ("CPUExecutionProvider",),
         opsets: Optional[Dict[str, int]] = None,
+        local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
     ):
         self.providers = providers
         self.constants: Dict[str, torch.Tensor] = {}
         self.kernels: List[Optional[torch_ops.OpRun]] = []
+        self.functions = local_functions.copy() if local_functions else {}
         self.CPU = torch.tensor([0]).to("cpu").device
         if "CUDAExecutionProvider" in providers:
             self.CUDA = torch.tensor([0]).to("cuda").device
@@ -83,6 +86,10 @@ def __init__(
             assert opsets is None, "proto is a model, opsets must be None in that case"
             assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
             self.opsets = {d.domain: d.version for d in proto.opset_import}
+            for f in proto.functions:
+                self.functions[f.domain, f.name] = TorchOnnxEvaluator(
+                    f, providers=providers, local_functions=self.functions
+                )
             self._build_initializers(proto.graph.initializer)
             self._build_initializers(proto.graph.node)
             self._build_kernels(proto.graph.node)
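
With this change, every FunctionProto attached to a ModelProto is wrapped in its own nested TorchOnnxEvaluator and registered under its (domain, name) key. A minimal sketch of a model that exercises this path; the "custom.Double" function and the final run call are illustrative assumptions (they are not part of the commit), and TorchOnnxEvaluator is assumed to be imported from this module:

    import onnx
    import onnx.helper as oh
    import torch

    # A local function "custom.Double" computing Y = X + X (hypothetical example).
    double = oh.make_function(
        "custom", "Double", ["X"], ["Y"],
        [oh.make_node("Add", ["X", "X"], ["Y"])],
        opset_imports=[oh.make_opsetid("", 18)],
    )
    graph = oh.make_graph(
        [oh.make_node("Double", ["X"], ["Y"], domain="custom")],
        "g",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])],
    )
    model = oh.make_model(
        graph,
        functions=[double],
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("custom", 1)],
    )

    # "custom.Double" now resolves to a nested TorchOnnxEvaluator in sess.functions.
    sess = TorchOnnxEvaluator(model)
    # Assumes the usual outputs/feeds signature of run() shown further below.
    got = sess.run(["Y"], {"X": torch.tensor([1.0, 2.0])})
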
@@ -138,24 +145,33 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
         kernels = get_kernels()
         self.kernels.clear()
         for node in nodes:
+            if (node.domain, node.op_type) in self.functions:
+                kernel = torch_ops.OpRunFunction(
+                    self.functions[node.domain, node.op_type], node, self.opsets[node.domain]
+                )
+                self.kernels.append(kernel)
+                continue
+
             if node.op_type == "Constant" and node.domain == "":
                 # Treated as a constant.
                 self.kernels.append(None)
                 continue
+
             opset = self.opsets[node.domain]
             key = node.domain, node.op_type, opset
             while key not in kernels and opset > 0:
                 opset -= 1
                 key = node.domain, node.op_type, opset
-            assert (
-                key in kernels
-            ), f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}"
+            assert key in kernels, (
+                f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}, "
+                f"local functions={sorted(self.functions)}"
+            )
             cls = kernels[key]
             if cls.device_dependent():
-                kernel = cls(node, opset, self.default_device)  # type: ignore[call-arg]
+                kernel2: torch_ops.OpRun = cls(node, opset, self.default_device)  # type: ignore[call-arg]
             else:
-                kernel = cls(node, opset)
-            self.kernels.append(kernel)
+                kernel2 = cls(node, opset)  # type: ignore[assignment]
+            self.kernels.append(kernel2)

     def run(
         self,
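
For context, the kernel lookup above falls back to the closest lower opset version that is registered. A standalone sketch of that loop, with a hypothetical registry rather than the actual content of get_kernels():

    from typing import Dict, Optional, Tuple

    # Hypothetical registry: an Add kernel registered for opsets 7 and 14 only.
    registry: Dict[Tuple[str, str, int], str] = {
        ("", "Add", 7): "Add_7",
        ("", "Add", 14): "Add_14",
    }

    def resolve(domain: str, op_type: str, opset: int) -> Optional[str]:
        # Mirrors the while-loop in _build_kernels: decrement the opset until a
        # registered (domain, op_type, opset) key is found.
        key = domain, op_type, opset
        while key not in registry and opset > 0:
            opset -= 1
            key = domain, op_type, opset
        return registry.get(key)

    assert resolve("", "Add", 18) == "Add_14"  # requested 18, falls back to 14
    assert resolve("", "Mul", 18) is None      # no kernel at all -> assertion fails in _build_kernels
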
@@ -165,7 +181,7 @@ def run(
165181 """
166182 Runs the ONNX model.
167183
168- :param outputs: outputs required:
184+ :param outputs: outputs required
169185 :param feeds: inputs
170186 :return: output tensors.
171187 """
@@ -218,7 +234,10 @@ def run(
             for name in self.last_used[it]:
                 self.runtime_info[name].clean_value()

-        res = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[assignment, union-attr]
+        assert all(
+            self.runtime_info[o].value is not None for o in outputs
+        ), "Not implemented yet when one output is None."
+        fres = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[union-attr]

         # clean previous execution
         for k in feeds:
@@ -227,5 +246,71 @@ def run(
             self.runtime_info[o].clean_value()

         if use_numpy:
-            return [None if a is None else a.detach().cpu().numpy() for a in res]  # type: ignore[union-attr]
-        return res  # type: ignore[return-value]
+            return [None if a is None else a.detach().cpu().numpy() for a in fres]
+        return fres
+
+    def run_with_values(
+        self, *args: Optional[torch_ops.OpRunValue]
+    ) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
+        """
+        Runs the ONNX model.
+
+        :param args: inputs
+        :return: output OpRunValue
+        """
+        assert all(
+            isinstance(a, torch_ops.OpRunValue) for a in args
+        ), f"Unexpected type in args: {[type(a) for a in args]}"
+        outputs = self.output_names
+
+        # sets constants
+        for k, v in self.constants.items():
+            r = self.runtime_info[k]
+            if not r.has_value:
+                r.set_value(
+                    torch_ops.OpRunValue(
+                        v.to(self.CUDA) if r.is_shape and self.on_cuda else v, True
+                    )
+                )
+
+        # inputs
+        for k, v in zip(self.input_names, args):
+            r = self.runtime_info[k]
+            r.set_value(torch_ops.OpRunValue(None if v is None else v.tensor))
+
+        # node execution
+        for it, kernel in enumerate(self.kernels):
+            if kernel is not None:
+                # kernel execution
+                inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
+                res = kernel.run(*inputs)
+                if isinstance(res, tuple):
+                    # outputs
+                    assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
+                        f"Unexpected output type {[type(o) for o in res]} "
+                        f"for kernel {type(kernel)}."
+                    )
+                    for name, t in zip(kernel.output, res):
+                        self.runtime_info[name].set_value(t)
+                else:
+                    assert isinstance(
+                        res, torch_ops.OpRunValue
+                    ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
+                    self.runtime_info[kernel.output[0]].set_value(res)
+
+            # free intermediate results
+            for name in self.last_used[it]:
+                self.runtime_info[name].clean_value()
+
+        assert all(
+            self.runtime_info[o].value is not None for o in outputs
+        ), "Not implemented yet when one output is None."
+        res2 = [torch_ops.OpRunValue(self.runtime_info[o].value.tensor) for o in outputs]  # type: ignore[assignment, union-attr]
+
+        # clean previous execution
+        for k in self.input_names:
+            self.runtime_info[k].clean_value()
+        for o in self.output_names:
+            self.runtime_info[o].clean_value()
+
+        return res2[0] if len(res2) == 1 else tuple(res2)  # type: ignore[index, return-value, arg-type]
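
A rough usage sketch contrasting the two entry points, assuming sess is the evaluator built in the earlier example and torch_ops refers to the kernels module this file already imports (OpRunValue wraps a tensor, as used in the code above; the exact run() signature is an assumption):

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])

    # Public entry point: named feeds in, a list of torch tensors out
    # (or numpy arrays when use_numpy is set).
    (y,) = sess.run(["Y"], {"X": x})

    # Entry point used by OpRunFunction for local functions: positional
    # OpRunValue wrappers in input order, OpRunValue(s) out.
    y_val = sess.run_with_values(torch_ops.OpRunValue(x))
    assert torch.allclose(y, y_val.tensor)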