diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 6e25c3f8..88feb5f6 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -3,7 +3,7 @@
 import os
 import pprint
 import sys
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import time
 import numpy as np
 import onnx
@@ -994,6 +994,26 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
     )
 
 
+_cache_export_times = []
+_main_export_function = torch.export.export
+
+
+def _torch_export_export(*args, _export=_main_export_function, **kwargs):
+    begin = time.perf_counter()
+    res = _export(*args, **kwargs)
+    duration = time.perf_counter() - begin
+    _cache_export_times.append(duration)
+    return res
+
+
+def _restore_torch_export_export(summary):
+    torch.export.export = _main_export_function
+    if _cache_export_times:
+        summary["time_torch_export_export"] = sum(_cache_export_times)
+        summary["time_torch_export_export_n"] = len(_cache_export_times)
+    _cache_export_times.clear()
+
+
 def call_exporter(
     data: Dict[str, Any],
     exporter: str,
@@ -1019,6 +1039,9 @@ def call_exporter(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
+    _cache_export_times.clear()
+    torch.export.export = _torch_export_export
+
     if exporter == "export" or exporter.startswith("export-"):
         # torch export
         summary, data = call_torch_export_export(
@@ -1029,6 +1052,7 @@
             optimization=optimization,
             do_run=do_run,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter.startswith("onnx-"):
         # torch export
@@ -1040,6 +1064,7 @@
             optimization=optimization,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter == "custom" or exporter.startswith("custom"):
         # torch export
@@ -1052,6 +1077,7 @@
             dump_folder=dump_folder,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter == "modelbuilder":
         # torch export
@@ -1063,6 +1089,7 @@
             optimization=optimization,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     raise NotImplementedError(
         f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
@@ -1634,6 +1661,97 @@ def call_torch_export_model_builder(
     return summary, data
 
 
+def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
+    """
+    Processes statistics coming from the exporters.
+    It takes a sequence of dictionaries (like a data frame)
+    and extracts some metrics.
+    """
+
+    def _simplify(p):
+        for s in [
+            "remove_unused",
+            "constant_folding",
+            "remove_identity",
+            "remove_duplicated_initializer",
+            "dynamic_dimension_naming",
+            "inline",
+            "check",
+            "build_graph_for_pattern",
+            "pattern_optimization",
+        ]:
+            if s in p or s.replace("_", "-") in p:
+                return s
+        if p.startswith(("apply_", "match_")):
+            return p
+        return "other"
+
+    def _add(d, a, v, use_max=False):
+        if v:
+            if a not in d:
+                d[a] = v
+            elif use_max:
+                d[a] = max(d[a], v)
+            else:
+                d[a] += v
+
+    counts: Dict[str, Any] = {}
+    applied_pattern_time: Dict[str, Any] = {}
+    applied_pattern_n: Dict[str, Any] = {}
+    matching_pattern_time: Dict[str, Any] = {}
+    matching_pattern_n: Dict[str, Any] = {}
+
+    for obs in data:
+        pattern = _simplify(obs["pattern"])
+        _add(counts, "opt_nodes_added", obs.get("added", 0))
+        _add(counts, "opt_nodes_removed", obs.get("removed", 0))
+        _add(counts, "opt_time_steps", obs.get("time_in", 0))
+        _add(counts, "opt_n_steps", 1)
+        _add(
+            counts,
+            "opt_n_iteration",
+            max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
+            use_max=True,
+        )
+
+        if pattern.startswith("apply_"):
+            _add(counts, "opt_n_applied_patterns", 1)
+            _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
+            _add(applied_pattern_time, pattern, obs.get("time_in", 0))
+            _add(applied_pattern_n, pattern, 1)
+        elif pattern.startswith("match_"):
+            _add(counts, "opt_n_matching_patterns", 1)
+            _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
+            _add(matching_pattern_time, pattern, obs.get("time_in", 0))
+            _add(matching_pattern_n, pattern, 1)
+        else:
+            _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
+            _add(counts, f"opt_n_{pattern}", 1)
+            _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
+            _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))
+
+    if applied_pattern_time:
+        longest = max((v, k) for k, v in applied_pattern_time.items())
+        counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
+            longest
+        )
+        longest = max((v, k) for k, v in applied_pattern_n.items())
+        counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest
+
+    if matching_pattern_time:
+        longest = max((v, k) for k, v in matching_pattern_time.items())
+        (
+            counts["opt_top_time_matching_pattern"],
+            counts["opt_top_time_matching_pattern_arg"],
+        ) = longest
+        longest = max((v, k) for k, v in matching_pattern_n.items())
+        counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
+            longest
+        )
+    counts["onnx_opt_optimized"] = 1
+    return counts
+
+
 def call_torch_export_custom(
     data: Dict[str, Any],
     exporter: str,
@@ -1763,67 +1881,10 @@ def call_torch_export_custom(
     if "ERR_export_onnx_c" in summary:
         return summary, data
 
-    new_stat = {}
+    new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
+    new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
     if "optimization" in opt_stats:
-        added, removed, time_in = 0, 0, 0.0
-        max_iter = 0
-        applied = {}
-        matched = set()
-        n_applied = 0
-        by_pattern = {}
-        by_pattern_n = {}
-        by_iter = {}
-        cst_added, cst_removed, cst_time_in = 0, 0, 0.0
-
-        for obs in opt_stats["optimization"]:
-            pattern = obs["pattern"]
-            if pattern == "constant_folding":
-                cst_added += obs.get("added", 0)
-                cst_removed += obs.get("removed", 0)
-                cst_time_in += obs.get("time_in", 0)
-            if pattern not in by_pattern:
-                by_pattern[pattern] = 0
-                by_pattern_n[pattern] = 0
-                by_iter[pattern] = 0
-            time_in += obs.get("time_in", 0)
-            added += obs.get("added", 0)
-            removed += obs.get("removed", 0)
-            max_iter = max(max_iter, obs.get("iteration", 0))
-            by_pattern[pattern] += obs.get("time_in", 0)
-            by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0)
-            if not pattern.startswith("match"):
-                by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0))
-            p = obs["pattern"]
-            if p.startswith("match_"):
-                matched.add(p)
-            elif p.startswith("apply_"):
-                key = f"op_opt_{p}"
-                key2 = f"op_opt_maxiter_{p}"
-                if key not in applied:
-                    applied[key] = 1
-                    applied[key2] = obs["iteration"]
-                else:
-                    applied[key] += 1
-                    applied[key2] = max(obs["iteration"], applied[key2])
-                n_applied += 1
-
-        new_stat.update(
-            dict(
-                onnx_opt_optimized=1,
-                op_opt_all_time_in=time_in,
-                op_opt_all_added=added,
-                op_opt_all_removed=removed,
-                op_opt_max_iter=max_iter,
-                op_opt_unique_matched=len(matched),
-                op_opt_unique_applied=len(applied),
-                op_opt_n_applied=n_applied,
-                time_export_optimization=time_in,
-                op_opt_export_optimization=time_in,
-                op_opt_cst_time_in=cst_time_in,
-                op_opt_cst_added=cst_added,
-                op_opt_cst_removed=cst_removed,
-            )
-        )
+        new_stat.update(process_statistics(opt_stats["optimization"]))
 
     summary.update(new_stat)
     assert epo is not None, "no onnx export was found"