1010 ValueInfoProto ,
1111 helper as oh ,
1212 load ,
13+ save as onnx_save ,
14+ shape_inference as shi ,
1315)
1416from onnx .defs import onnx_opset_version
1517import onnxruntime
2022 InferenceSessionForNumpy ,
2123 _InferenceSession ,
2224)
25+ from .evaluator import ExtendedReferenceEvaluator
2326
2427PROTO = (FunctionProto , ModelProto , GraphProto , NodeProto )
2528Proto = Union [FunctionProto , ModelProto , GraphProto , NodeProto ]
@@ -304,7 +307,7 @@ def _make_model_proto(
304307 return onx
305308
306309 @classmethod
307- def _get_hidden_inputs (self , graph : GraphProto ) -> Set [str , Any ]:
310+ def _get_hidden_inputs (self , graph : GraphProto ) -> Set [str ]:
308311 """
309312 Returns the hidden inputs (inputs coming from an upper context)
310313 used by a subgraph.
@@ -331,18 +334,34 @@ def _get_sess(
331334 onx = node
332335 else :
333336 assert isinstance (node , NodeProto ), f"Unexpected type { type (node )} for node"
334- unique_names = set ()
335- vinputs = []
336- for i , it in zip (node .input , inputs ):
337- if i == "" or i in unique_names :
338- continue
339- unique_names .add (i )
340- value = oh .make_tensor_value_info (i , dtype_to_tensor_dtype (it .dtype ), it .shape )
341- vinputs .append (value )
337+ if node .op_type == "Constant" :
338+ # We force the type to be a boolean.
339+ ref = ExtendedReferenceEvaluator (node )
340+ cst = ref .run (None , {})[0 ]
341+ vinputs = []
342+ voutputs = [
343+ oh .make_tensor_value_info (
344+ node .output [0 ], dtype_to_tensor_dtype (cst .dtype ), cst .shape
345+ )
346+ ]
347+ else :
348+ unique_names = set ()
349+ vinputs = []
350+ for i , it in zip (node .input , inputs ):
351+ if i == "" or i in unique_names :
352+ continue
353+ unique_names .add (i )
354+ value = oh .make_tensor_value_info (
355+ i , dtype_to_tensor_dtype (it .dtype ), it .shape
356+ )
357+ vinputs .append (value )
358+
359+ # no need to run shape inference
360+ voutputs = [oh .make_value_info (o , TypeProto ()) for o in node .output ]
342361
343- # no need to run shape inference
344- voutputs = [oh .make_value_info (o , TypeProto ()) for o in node .output ]
345362 onx = self ._make_model_proto ([node ], vinputs , voutputs )
363+ # That helps fixing bugs.
364+ onx = shi .infer_shapes (onx )
346365
347366 cls = (
348367 InferenceSessionForNumpy
@@ -356,6 +375,7 @@ def _get_sess(
356375 onnxruntime .capi .onnxruntime_pybind11_state .InvalidGraph ,
357376 onnxruntime .capi .onnxruntime_pybind11_state .InvalidArgument ,
358377 ) as e :
378+ onnx_save (onx , "_debug_OnnxruntimeEvaluator_last_failure.onnx" )
359379 raise RuntimeError (
360380 f"Unable to infer a session with inputs\n { string_type (inputs )} "
361381 f"\n due to { e } \n { pretty_onnx (onx )} "
@@ -461,8 +481,7 @@ def _run_if(
461481 self ._cache [key ] = onx , sess = self ._get_sess_if (node , name , inputs , results )
462482
463483 assert hasattr (sess , "run" ), f"Missing method run for type { type (sess )} "
464- input_names = [i .name for i in sess .get_inputs ()]
465- feeds = {name : results [name ] for name in input_names }
484+ feeds = {name : results [name ] for name in sess .input_names }
466485 outputs = sess .run (None , feeds )
467486 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
468487 return outputs
@@ -515,8 +534,7 @@ def _run_scan(
515534 self ._cache [key ] = onx , sess = self ._get_sess_scan (node , name , inputs , results )
516535
517536 assert hasattr (sess , "run" ), f"Missing method run for type { type (sess )} "
518- input_names = [i .name for i in sess .get_inputs ()]
519- feeds = {name : results [name ] for name in input_names }
537+ feeds = {name : results [name ] for name in sess .input_names }
520538 outputs = sess .run (None , feeds )
521539 assert isinstance (outputs , list ), f"Unexpected type for outputs { type (outputs )} "
522540 return outputs
0 commit comments