Skip to content

Commit 3cc290a

Browse files
committed
add statistic on types
1 parent 20e7f28 commit 3cc290a

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

onnx_diagnostic/torch_models/validate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ def validate_model(
712712
print(f"[validate_model] done (dump onnx) in {duration}")
713713
data["onnx_filename"] = onnx_filename
714714
summary["time_onnx_save"] = duration
715+
summary.update(compute_statistics(onnx_filename))
715716
if verbose:
716717
print(f"[validate_model] dumps statistics in {dump_folder!r}...")
717718
dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
@@ -815,6 +816,39 @@ def validate_model(
815816
return summary, data
816817

817818

819+
def compute_statistics(onnx_filename: str) -> Dict[str, float]:
820+
"""Computes some statistics on the model itself."""
821+
onx = onnx.load(onnx_filename, load_external_data=False)
822+
823+
def node_iter(proto):
824+
if isinstance(proto, onnx.ModelProto):
825+
yield from node_iter(proto.graph)
826+
for f in proto.functions:
827+
yield from node_iter(f)
828+
elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
829+
for node in proto.node:
830+
yield node
831+
for att in node.attribute:
832+
if att.type == onnx.AttributeProto.GRAPH:
833+
yield from att.g
834+
if hasattr(proto, "initializer"):
835+
yield from proto.initializer
836+
else:
837+
raise NotImplementedError(f"Unexpected type={type(proto)}")
838+
839+
counts = {}
840+
for proto in node_iter(onx):
841+
if isinstance(proto, onnx.NodeProto):
842+
key = f"n_node_{proto.op_type}"
843+
else:
844+
key = f"n_node_initializer_{proto.data_type}"
845+
846+
if key not in counts:
847+
counts[key] = 0
848+
counts[key] += 1
849+
return counts
850+
851+
818852
def _validate_do_run_model(
819853
data, summary, key, tag, expected_tag, verbose, repeat, warmup, quiet
820854
):

0 commit comments

Comments
 (0)