Skip to content

Commit 7506754

Browse files
committed
types
1 parent a0a38a8 commit 7506754

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

onnx_diagnostic/torch_models/validate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,7 +1661,7 @@ def call_torch_export_model_builder(
16611661
return summary, data
16621662

16631663

1664-
def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, float]:
1664+
def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
16651665
"""
16661666
Processes statistics coming from the exporters.
16671667
It takes a sequence of dictionaries (like a data frame)
@@ -1695,11 +1695,11 @@ def _add(d, a, v, use_max=False):
16951695
else:
16961696
d[a] += v
16971697

1698-
counts = {}
1699-
applied_pattern_time = {}
1700-
applied_pattern_n = {}
1701-
matching_pattern_time = {}
1702-
matching_pattern_n = {}
1698+
counts: Dict[str, Any] = {}
1699+
applied_pattern_time: Dict[str, Any] = {}
1700+
applied_pattern_n: Dict[str, Any] = {}
1701+
matching_pattern_time: Dict[str, Any] = {}
1702+
matching_pattern_n: Dict[str, Any] = {}
17031703

17041704
for obs in data:
17051705
pattern = _simplify(obs["pattern"])
@@ -1875,7 +1875,7 @@ def call_torch_export_custom(
18751875
if "ERR_export_onnx_c" in summary:
18761876
return summary, data
18771877

1878-
new_stat = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
1878+
new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
18791879
new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
18801880
if "optimization" in opt_stats:
18811881
new_stat.update(process_statistics(opt_stats["optimization"]))

0 commit comments

Comments
 (0)