@@ -1661,7 +1661,7 @@ def call_torch_export_model_builder(
16611661 return summary , data
16621662
16631663
1664- def process_statistics (data : Sequence [Dict [str , float ]]) -> Dict [str , float ]:
1664+ def process_statistics (data : Sequence [Dict [str , float ]]) -> Dict [str , Any ]:
16651665 """
16661666 Processes statistics coming from the exporters.
16671667 It takes a sequence of dictionaries (like a data frame)
@@ -1695,11 +1695,11 @@ def _add(d, a, v, use_max=False):
16951695 else :
16961696 d [a ] += v
16971697
1698- counts = {}
1699- applied_pattern_time = {}
1700- applied_pattern_n = {}
1701- matching_pattern_time = {}
1702- matching_pattern_n = {}
1698+ counts : Dict [ str , Any ] = {}
1699+ applied_pattern_time : Dict [ str , Any ] = {}
1700+ applied_pattern_n : Dict [ str , Any ] = {}
1701+ matching_pattern_time : Dict [ str , Any ] = {}
1702+ matching_pattern_n : Dict [ str , Any ] = {}
17031703
17041704 for obs in data :
17051705 pattern = _simplify (obs ["pattern" ])
@@ -1875,7 +1875,7 @@ def call_torch_export_custom(
18751875 if "ERR_export_onnx_c" in summary :
18761876 return summary , data
18771877
1878- new_stat = {k : v for k , v in opt_stats .items () if k .startswith ("time_" )}
1878+ new_stat : Dict [ str , Any ] = {k : v for k , v in opt_stats .items () if k .startswith ("time_" )}
18791879 new_stat .update ({k [5 :]: v for k , v in opt_stats .items () if k .startswith ("stat_time_" )})
18801880 if "optimization" in opt_stats :
18811881 new_stat .update (process_statistics (opt_stats ["optimization" ]))
0 commit comments