11from typing import Dict , List , Tuple , Union
22
33
4- ReportKeyType = Union [str , Tuple [Union [int , str ], ...]]
4+ ReportKeyNameType = Union [str , Tuple [str , int , str ]]
5+ ReportKeyValueType = Tuple [int , Tuple [int , ...]]
56
67
78class ReportResultsComparison :
@@ -13,7 +14,7 @@ class ReportResultsComparison:
1314 :param tensors: tensor
1415 """
1516
16- def __init__ (self , tensors : Dict [ReportKeyType , "torch.Tensor" ]): # noqa: F821
17+ def __init__ (self , tensors : Dict [ReportKeyNameType , "torch.Tensor" ]): # noqa: F821
1718 from ..helpers .onnx_helper import dtype_to_tensor_dtype
1819 from ..helpers import max_diff
1920
@@ -22,7 +23,7 @@ def __init__(self, tensors: Dict[ReportKeyType, "torch.Tensor"]): # noqa: F821
2223 self .tensors = tensors
2324 self ._build_mapping ()
2425
25- def key (self , tensor : "torch.Tensor" ) -> ReportKeyType : # noqa: F821
26+ def key (self , tensor : "torch.Tensor" ) -> ReportKeyValueType : # noqa: F821
2627 "Returns a key for a tensor, (onnx dtype, shape)."
2728 return self .dtype_to_tensor_dtype (tensor .dtype ), tuple (map (int , tensor .shape ))
2829
@@ -41,13 +42,13 @@ def clear(self):
4142 self .report_cmp = {}
4243
4344 @property
44- def value (self ) -> Dict [Tuple [str , ReportKeyType ], Dict [str , Union [float , str ]]]:
45+ def value (self ) -> Dict [Tuple [str , ReportKeyNameType ], Dict [str , Union [float , str ]]]:
4546 "Returns the report."
4647 return self .report_cmp
4748
4849 def report (
4950 self , outputs : Dict [str , "torch.Tensor" ] # noqa: F821
50- ) -> List [Tuple [str , ReportKeyType ]]:
51+ ) -> List [Tuple [str , ReportKeyNameType , Dict [ str , Union [ float , str ]] ]]:
5152 """
5253 For every tensor in outputs, compares it to every tensor held by
5354 this class if it shares the same type and shape. The function returns
0 commit comments