Skip to content

Commit 3460174

Browse files
committed
mypy
1 parent 82ddf97 commit 3460174

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

onnx_diagnostic/reference/report_results_comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def report(
5555
the results of the comparison. The function also collects the results
5656
into a dictionary the user can retrieve later.
5757
"""
58-
res = []
58+
res: List[Tuple[str, ReportKeyNameType, Dict[str, Union[float, str]]]] = []
5959
for name, tensor in outputs.items():
6060
key = self.key(tensor)
6161
if key not in self.mapping:
6262
continue
63-
cache = {}
63+
cache: Dict["torch.device", "torch.Tensor"] = {} # noqa: F821, UP037
6464
for held_key in self.mapping[key]:
6565
t2 = self.tensors[held_key]
6666
if hasattr(t2, "device") and hasattr(tensor, "device"):

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,9 +544,9 @@ def run(
544544
zip(
545545
kernel.output,
546546
(
547-
tuple(r.tensor for r in res)
547+
tuple((r.tensor if r else None) for r in res)
548548
if isinstance(res, tuple)
549-
else (res.tensor,)
549+
else ((res.tensor if res else None),)
550550
),
551551
)
552552
)

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ disable_error_code = ["assignment", "arg-type", "name-defined", "union-attr"]
4444
module = ["onnx_diagnostic.helpers.ort_session"]
4545
disable_error_code = ["union-attr"]
4646

47+
[[tool.mypy.overrides]]
48+
module = ["onnx_diagnostic.reference.report_results_comparison"]
49+
disable_error_code = ["name-defined"]
50+
4751
[[tool.mypy.overrides]]
4852
module = ["onnx_diagnostic.reference.torch_ops.*"]
4953
disable_error_code = ["override"]

0 commit comments

Comments
 (0)