1- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
1+ from typing import Any , Dict , List , Optional , Sequence , Set , Tuple , Union
22import numpy as np
33from onnx import (
4+ AttributeProto ,
45 GraphProto ,
56 FunctionProto ,
67 ModelProto ,
@@ -250,7 +251,7 @@ def run(
250251 inputs = [(results [i ] if i != "" else None ) for i in node .input ]
251252 if node .op_type == "If" and node .domain == "" :
252253 outputs = self ._run_if (node , inputs , results )
253- elif node .op_type == "Scan" and node .domain == "" :
254+ elif node .op_type in { "Scan" , "Loop" } and node .domain == "" :
254255 outputs = self ._run_scan (node , inputs , results )
255256 elif self ._is_local_function (node ):
256257 outputs = self ._run_local (node , inputs , results )
@@ -302,6 +303,27 @@ def _make_model_proto(
302303
303304 return onx
304305
306+ @classmethod
307+ def _get_hidden_inputs (self , graph : GraphProto ) -> Set [str , Any ]:
308+ """
309+ Returns the hidden inputs (inputs coming from an upper context)
310+ used by a subgraph.
311+ """
312+ hidden = set ()
313+ memo = set (i .name for i in graph .initializer )
314+ memo |= set (i .name for i in graph .sparse_initializer )
315+ for node in graph .node :
316+ for i in node .input :
317+ if i not in memo :
318+ hidden .add (i )
319+ for att in node .attribute :
320+ if att .type == AttributeProto .GRAPH and att .g :
321+ hid = self ._get_hidden_inputs (att .g )
322+ less = set (h for h in hid if h not in memo )
323+ hidden |= less
324+ memo |= set (node .output )
325+ return hidden
326+
305327 def _get_sess (
306328 self , node : Union [ModelProto , NodeProto ], inputs : List [Any ]
307329 ) -> Tuple [ModelProto , _InferenceSession ]:
@@ -341,7 +363,7 @@ def _get_sess(
341363 return onx , sess
342364
343365 def _get_sess_init_subgraph (
344- self , node : NodeProto , inputs : List [Any ], context : Dict [str , Any ]
366+ self , node : NodeProto , inputs : List [Any ], context : Dict [str , Any ], g : GraphProto
345367 ) -> List [Any ]:
346368 unique_names = set ()
347369 vinputs = []
@@ -352,8 +374,9 @@ def _get_sess_init_subgraph(
352374 value = oh .make_tensor_value_info (i , dtype_to_tensor_dtype (it .dtype ), it .shape )
353375 vinputs .append (value )
354376
377+ reduced_set = self ._get_hidden_inputs (g )
355378 for i , v in context .items ():
356- if i not in unique_names :
379+ if i in reduced_set and i not in unique_names :
357380 unique_names .add (i )
358381 value = oh .make_tensor_value_info (i , dtype_to_tensor_dtype (v .dtype ), v .shape )
359382 vinputs .append (value )
@@ -362,13 +385,12 @@ def _get_sess_init_subgraph(
362385 def _get_sess_if (
363386 self , node : NodeProto , branch : str , inputs : List [Any ], context : Dict [str , Any ]
364387 ) -> Tuple [ModelProto , "OnnxruntimeEvaluator" ]:
365- vinputs = self ._get_sess_init_subgraph (node , inputs , context )
366-
367388 g = None
368389 for att in node .attribute :
369390 if att .name == branch :
370391 g = att .g
371392 assert g , f"Missing attribute { branch !r} "
393+ vinputs = self ._get_sess_init_subgraph (node , inputs , context , g )
372394
373395 voutputs = g .output
374396
@@ -439,20 +461,21 @@ def _run_if(
439461 self ._cache [key ] = onx , sess = self ._get_sess_if (node , name , inputs , results )
440462
441463 assert hasattr (sess , "run" ), f"Missing method run for type { type (sess )} "
464+ input_names = [i .name for i in sess .get_inputs ()]
465+ feeds = {name : results [name ] for name in input_names }
442466 outputs = sess .run (None , feeds )
443467 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
444468 return outputs
445469
446470 def _get_sess_scan (
447471 self , node : NodeProto , branch : str , inputs : List [Any ], context : Dict [str , Any ]
448472 ) -> Tuple [ModelProto , "OnnxruntimeEvaluator" ]:
449- vinputs = self ._get_sess_init_subgraph (node , inputs , context )
450-
451473 g = None
452474 for att in node .attribute :
453475 if att .name == branch :
454476 g = att .g
455477 assert g , f"Missing attribute { branch !r} "
478+ vinputs = self ._get_sess_init_subgraph (node , inputs , context , g )
456479
457480 voutputs = []
458481 for name , goutput in zip (node .output , g .output ):
@@ -492,6 +515,8 @@ def _run_scan(
492515 self ._cache [key ] = onx , sess = self ._get_sess_scan (node , name , inputs , results )
493516
494517 assert hasattr (sess , "run" ), f"Missing method run for type { type (sess )} "
518+ input_names = [i .name for i in sess .get_inputs ()]
519+ feeds = {name : results [name ] for name in input_names }
495520 outputs = sess .run (None , feeds )
496521 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
497522 return outputs
0 commit comments