Skip to content

Commit 82ddf97

Browse files
committed
mypy
1 parent 6c88174 commit 82ddf97

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

onnx_diagnostic/reference/report_results_comparison.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Dict, List, Tuple, Union
22

33

4-
ReportKeyType = Union[str, Tuple[Union[int, str], ...]]
4+
ReportKeyNameType = Union[str, Tuple[str, int, str]]
5+
ReportKeyValueType = Tuple[int, Tuple[int, ...]]
56

67

78
class ReportResultsComparison:
@@ -13,7 +14,7 @@ class ReportResultsComparison:
1314
:param tensors: tensor
1415
"""
1516

16-
def __init__(self, tensors: Dict[ReportKeyType, "torch.Tensor"]): # noqa: F821
17+
def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]): # noqa: F821
1718
from ..helpers.onnx_helper import dtype_to_tensor_dtype
1819
from ..helpers import max_diff
1920

@@ -22,7 +23,7 @@ def __init__(self, tensors: Dict[ReportKeyType, "torch.Tensor"]): # noqa: F821
2223
self.tensors = tensors
2324
self._build_mapping()
2425

25-
def key(self, tensor: "torch.Tensor") -> ReportKeyType: # noqa: F821
26+
def key(self, tensor: "torch.Tensor") -> ReportKeyValueType: # noqa: F821
2627
"Returns a key for a tensor, (onnx dtype, shape)."
2728
return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
2829

@@ -41,13 +42,13 @@ def clear(self):
4142
self.report_cmp = {}
4243

4344
@property
44-
def value(self) -> Dict[Tuple[str, ReportKeyType], Dict[str, Union[float, str]]]:
45+
def value(self) -> Dict[Tuple[str, ReportKeyNameType], Dict[str, Union[float, str]]]:
4546
"Returns the report."
4647
return self.report_cmp
4748

4849
def report(
4950
self, outputs: Dict[str, "torch.Tensor"] # noqa: F821
50-
) -> List[Tuple[str, ReportKeyType]]:
51+
) -> List[Tuple[str, ReportKeyNameType, Dict[str, Union[float, str]]]]:
5152
"""
5253
For every tensor in outputs, compares it to every tensor held by
5354
this class if it shares the same type and shape. The function returns

0 commit comments

Comments
 (0)