Skip to content

Commit 4ebd848

Browse files
committed
mypy
1 parent 58ba91e commit 4ebd848

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def __init__(
105105
self.nodes = None
106106
self.rt_inits_ = None
107107
self.rt_nodes_ = None
108-
self.local_functions = None
109108
else:
110109
self.nodes = (
111110
[self.proto]
@@ -124,19 +123,19 @@ def __init__(
124123
else {}
125124
)
126125
self.rt_nodes_ = self.nodes.copy()
127-
self.local_functions: Dict[
128-
Tuple[str, str], "OnnxruntimeEvaluator" # noqa: UP037
129-
] = (
130-
{(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
131-
if hasattr(self.proto, "functions")
132-
else {}
133-
)
134-
if local_functions:
135-
self.local_functions.update(local_functions)
126+
127+
self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = ( # noqa: UP037
128+
{(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
129+
if hasattr(self.proto, "functions")
130+
else {}
131+
)
132+
if local_functions:
133+
self.local_functions.update(local_functions)
136134

137135
@property
138136
def input_names(self) -> List[str]:
139137
"Returns input names."
138+
assert self.proto, "self.proto is empty"
140139
if isinstance(self.proto, NodeProto):
141140
return self.nodes[0].input
142141
return [
@@ -149,6 +148,7 @@ def input_names(self) -> List[str]:
149148
@property
150149
def output_names(self) -> List[str]:
151150
"Returns output names."
151+
assert self.proto, "self.proto is empty"
152152
if isinstance(self.proto, NodeProto):
153153
return self.nodes[0].output
154154
return [
@@ -218,19 +218,20 @@ def run(
218218
# runs a whole
219219
if self.sess_ is None:
220220
_, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
221+
assert self.sess_, "mypy not happy"
221222
return self.sess_.run(outputs, feed_inputs)
222223
if outputs is None:
223224
outputs = self.output_names
224-
results: Dict[str, Any] = self.rt_inits_.copy()
225+
results: Dict[str, Any] = (self.rt_inits_ or {}).copy()
225226

226-
for k, v in self.rt_inits_.items():
227+
for k, v in results.items():
227228
self._log(2, " +C %s: %s", k, v)
228229
for k, v in feed_inputs.items():
229230
assert not isinstance(v, str), f"Unexpected type str for {k!r}"
230231
self._log(2, " +I %s: %s", k, v)
231232
results[k] = v
232233

233-
for node in self.rt_nodes_:
234+
for node in self.rt_nodes_ or []:
234235
self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
235236
for i in node.input:
236237
if i != "" and i not in results:

0 commit comments

Comments
 (0)