@@ -105,7 +105,6 @@ def __init__(
105105 self .nodes = None
106106 self .rt_inits_ = None
107107 self .rt_nodes_ = None
108- self .local_functions = None
109108 else :
110109 self .nodes = (
111110 [self .proto ]
@@ -124,19 +123,19 @@ def __init__(
124123 else {}
125124 )
126125 self .rt_nodes_ = self .nodes .copy ()
127- self .local_functions : Dict [
128- Tuple [str , str ], "OnnxruntimeEvaluator" # noqa: UP037
129- ] = (
130- {(f .domain , f .name ): self .__class__ (f ) for f in self .proto .functions }
131- if hasattr (self .proto , "functions" )
132- else {}
133- )
134- if local_functions :
135- self .local_functions .update (local_functions )
126+
127+ self .local_functions : Dict [Tuple [str , str ], "OnnxruntimeEvaluator" ] = ( # noqa: UP037
128+ {(f .domain , f .name ): self .__class__ (f ) for f in self .proto .functions }
129+ if hasattr (self .proto , "functions" )
130+ else {}
131+ )
132+ if local_functions :
133+ self .local_functions .update (local_functions )
136134
137135 @property
138136 def input_names (self ) -> List [str ]:
139137 "Returns input names."
138+ assert self .proto , "self.proto is empty"
140139 if isinstance (self .proto , NodeProto ):
141140 return self .nodes [0 ].input
142141 return [
@@ -149,6 +148,7 @@ def input_names(self) -> List[str]:
149148 @property
150149 def output_names (self ) -> List [str ]:
151150 "Returns output names."
151+ assert self .proto , "self.proto is empty"
152152 if isinstance (self .proto , NodeProto ):
153153 return self .nodes [0 ].output
154154 return [
@@ -218,19 +218,20 @@ def run(
218218 # runs a whole
219219 if self .sess_ is None :
220220 _ , self .sess_ = self ._get_sess (self .proto , list (feed_inputs .values ()))
221+ assert self .sess_ , "mypy not happy"
221222 return self .sess_ .run (outputs , feed_inputs )
222223 if outputs is None :
223224 outputs = self .output_names
224- results : Dict [str , Any ] = self .rt_inits_ .copy ()
225+ results : Dict [str , Any ] = ( self .rt_inits_ or {}) .copy ()
225226
226- for k , v in self . rt_inits_ .items ():
227+ for k , v in results .items ():
227228 self ._log (2 , " +C %s: %s" , k , v )
228229 for k , v in feed_inputs .items ():
229230 assert not isinstance (v , str ), f"Unexpected type str for { k !r} "
230231 self ._log (2 , " +I %s: %s" , k , v )
231232 results [k ] = v
232233
233- for node in self .rt_nodes_ :
234+ for node in self .rt_nodes_ or [] :
234235 self ._log (1 , "%s(%s) -> %s" , node .op_type , node .input , node .output )
235236 for i in node .input :
236237 if i != "" and i not in results :
0 commit comments