1- from typing import Dict , List , Tuple , Union
1+ from typing import Any , Dict , List , Tuple , Union
22
33
44ReportKeyNameType = Union [str , Tuple [str , int , str ]]
@@ -9,7 +9,7 @@ class ReportResultsComparison:
99 """
1010 Holds tensors a runtime can use as a reference to compare
1111 intermediate results.
12- See :meth:`onnx_diagnostic.reference.TorchOnnxEvalutor .run`.
12+ See :meth:`onnx_diagnostic.reference.TorchOnnxEvaluator .run`.
1313
1414 :param tensors: tensor
1515 """
@@ -40,23 +40,37 @@ def _build_mapping(self):
4040 def clear (self ):
4141 """Clears the last report."""
4242 self .report_cmp = {}
43+ self .unique_run_names = set ()
4344
4445 @property
4546 def value (self ) -> Dict [Tuple [str , ReportKeyNameType ], Dict [str , Union [float , str ]]]:
4647 "Returns the report."
4748 return self .report_cmp
4849
50+ @property
51+ def data (self ) -> List [Dict [str , Any ]]:
52+ "Returns data which can be consumed by a dataframe."
53+ rows = []
54+ for k , v in self .value .items ():
55+ (i_run , run_name ), ref_name = k
56+ d = dict (run_index = i_run , run_name = run_name , ref_name = ref_name )
57+ d .update (v )
58+ rows .append (d )
59+ return rows
60+
4961 def report (
5062 self , outputs : Dict [str , "torch.Tensor" ] # noqa: F821
51- ) -> List [Tuple [str , ReportKeyNameType , Dict [str , Union [float , str ]]]]:
63+ ) -> List [Tuple [Tuple [ int , str ] , ReportKeyNameType , Dict [str , Union [float , str ]]]]:
5264 """
5365 For every tensor in outputs, compares it to every tensor held by
5466 this class if it shares the same type and shape. The function returns
5567 the results of the comparison. The function also collects the results
5668 into a dictionary the user can retrieve later.
5769 """
58- res : List [Tuple [str , ReportKeyNameType , Dict [str , Union [float , str ]]]] = []
70+ res : List [Tuple [Tuple [ int , str ] , ReportKeyNameType , Dict [str , Union [float , str ]]]] = []
5971 for name , tensor in outputs .items ():
72+ i_run = len (self .unique_run_names )
73+ self .unique_run_names .add (name )
6074 key = self .key (tensor )
6175 if key not in self .mapping :
6276 continue
@@ -71,6 +85,6 @@ def report(
7185 diff = self .max_diff (t , t2 )
7286 else :
7387 diff = self .max_diff (tensor , t2 )
74- res .append ((name , held_key , diff )) # type: ignore[arg-type]
75- self .report_cmp [name , held_key ] = diff
88+ res .append ((i_run , name , held_key , diff )) # type: ignore[arg-type]
89+ self .report_cmp [( i_run , name ) , held_key ] = diff
7690 return res
0 commit comments