Skip to content

Commit 2677308

Browse files
committed
mypy
1 parent 25d17cb commit 2677308

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,10 @@ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
168168
)
169169
cls = kernels[key]
170170
if cls.device_dependent():
171-
kernel: torch_ops.OpRun = cls(node, opset, self.default_device) # type: ignore[call-arg]
171+
kernel2: torch_ops.OpRun = cls(node, opset, self.default_device) # type: ignore[call-arg]
172172
else:
173-
kernel = cls(node, opset)
174-
self.kernels.append(kernel)
173+
kernel2 = cls(node, opset) # type: ignore[assignment]
174+
self.kernels.append(kernel2)
175175

176176
def run(
177177
self,
@@ -234,6 +234,9 @@ def run(
234234
for name in self.last_used[it]:
235235
self.runtime_info[name].clean_value()
236236

237+
assert all(
238+
self.runtime_info[o].value is not None for o in outputs
239+
), "Not implemented yet when one output is None."
237240
fres = [self.runtime_info[o].value.tensor for o in outputs]
238241

239242
# clean previous execution
@@ -299,6 +302,9 @@ def run_with_values(
299302
for name in self.last_used[it]:
300303
self.runtime_info[name].clean_value()
301304

305+
assert all(
306+
self.runtime_info[o].value is not None for o in outputs
307+
), "Not implemented yet when one output is None."
302308
res = [torch_ops.OpRunValue(self.runtime_info[o].value.tensor) for o in outputs] # type: ignore[assignment, union-attr]
303309

304310
# clean previous execution

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ disable_error_code = ["union-attr"]
4848
module = ["onnx_diagnostic.reference.torch_ops.*"]
4949
disable_error_code = ["override"]
5050

51+
[[tool.mypy.overrides]]
52+
module = ["onnx_diagnostic.reference.torch_ops._op_run"]
53+
disable_error_code = ["name-defined"]
54+
5155
[[tool.mypy.overrides]]
5256
module = ["onnx_diagnostic.torch_export_patches.*"]
5357
disable_error_code = ["arg-type", "assignment", "attr-defined", "index", "misc", "name-defined", "operator", "return-value", "union-attr"]

0 commit comments

Comments
 (0)