Skip to content

Commit 91f9b15

Browse files
committed
feat: extract onnx info
1 parent 06939a9 commit 91f9b15

File tree

2 files changed

+70
-31
lines changed

2 files changed

+70
-31
lines changed

src/sasctl/utils/model_info.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Copyright © 2023, SAS Institute Inc., Cary, NC, USA. All Rights Reserved.
55
# SPDX-License-Identifier: Apache-2.0
66

7+
import math
78
from abc import ABC, abstractmethod
89
from typing import Any, Callable, Dict, List, Union
910

@@ -17,6 +18,9 @@
1718

1819
try:
1920
import onnx
21+
22+
# ONNX serializes models using protobuf, so this should be safe
23+
from google.protobuf import json_format
2024
except ImportError:
2125
onnx = None
2226

@@ -198,9 +202,6 @@ def __init__(self, model, X, y=None):
198202
"The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
199203
)
200204

201-
# ONNX serializes models using protobuf, so this should be safe
202-
from google.protobuf import json_format
203-
204205
# TODO: size of X should match size of graph.input
205206

206207
self._model = model
@@ -209,22 +210,18 @@ def __init__(self, model, X, y=None):
209210

210211
inferred_model = onnx.shape_inference.infer_shapes(model)
211212

212-
inputs = [json_format.MessageToDict(i) for i in inferred_model.graph.input]
213-
outputs = [json_format.MessageToDict(o) for o in inferred_model.graph.output]
213+
inputs = [self._tensor_to_dataframe(i) for i in inferred_model.graph.input]
214+
outputs = [self._tensor_to_dataframe(o) for o in inferred_model.graph.output]
214215

215216
if len(inputs) > 1:
216217
pass # TODO: warn that only first input will be captured
217218

218219
if len(outputs) > 1:
219220
pass # TODO: warn that only the first output will be captured
220221

221-
inputs[0]["type"]["tensorType"]["elemType"]
222-
inputs[0]["type"]["tensorType"]["shape"]
222+
self._X_df = inputs[0]
223+
self._y_df = outputs[0]
223224

224-
self._properties = {
225-
"description": model.doc_string,
226-
"opset": model.opset_import
227-
}
228225
# initializer (static params)
229226

230227
# for field in model.ListFields():
@@ -243,45 +240,76 @@ def __init__(self, model, X, y=None):
243240
# producerVersion
244241
# opsetImport
245242

246-
247243
# # list of (FieldDescriptor, value)
248244
# fields = model.ListFields()
249-
# inferred_model = onnx.shape_inference.infer_shapes(model)
250-
#
251-
# inputs = model.graph.input
252-
# assert len(inputs) == 1
253-
# i = inputs[0]
254-
# print(i.name)
255-
# print(i.type)
256-
# print(i.type.tensor_type.shape)
245+
246+
@staticmethod
def _tensor_to_dataframe(tensor):
    """Build an empty DataFrame describing the flattened layout of an ONNX tensor.

    One column is created per tensor element.  Columns are named
    ``<name>1``, ``<name>2``, ... where ``<name>`` is the tensor's name, or
    "Var" if the tensor has no name.

    Parameters
    ----------
    tensor : onnx.onnx_ml_pb2.ValueInfoProto or dict
        A protobuf `Message` describing a graph input or output, or the
        dictionary produced by `json_format.MessageToDict` for such a message.

    Returns
    -------
    pandas.DataFrame
        A DataFrame with no rows whose columns describe the tensor elements
        and whose dtype is the NumPy equivalent of the tensor element type.

    Raises
    ------
    ValueError
        If `tensor` is neither a `ValueInfoProto` nor a dict, or if it does
        not describe a tensor type (e.g. it is a sequence or map).

    Examples
    --------
    df = _tensor_to_dataframe(model.graph.input[0])

    """
    if isinstance(tensor, onnx.onnx_ml_pb2.ValueInfoProto):
        tensor = json_format.MessageToDict(tensor)
    elif not isinstance(tensor, dict):
        raise ValueError(f"Unexpected type {type(tensor)}.")

    name = tensor.get("name", "Var")
    type_ = tensor["type"]

    if "tensorType" not in type_:
        raise ValueError(f"Received an unexpected ONNX input type: {type_}.")

    tensor_type = type_["tensorType"]
    dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type["elemType"])

    # MessageToDict omits empty fields, so a scalar tensor has no "dim" entry
    # (and a tensor of unknown rank has no "shape" entry at all).
    dims = tensor_type.get("shape", {}).get("dim", [])

    # Tuple of tensor dimensions e.g. (1, 1, 24).  Protobuf int64 values are
    # serialized as strings, hence the int() conversion.  Symbolic or unknown
    # dimensions (e.g. a dynamic batch axis stored as "dimParam") carry no
    # "dimValue"; they are counted as size 1 so that they do not affect the
    # number of columns.
    # NOTE(review): size 1 assumes a dynamic axis is a batch-like axis - confirm.
    input_dims = tuple(int(d.get("dimValue", 1)) for d in dims)

    # math.prod of an empty tuple is 1, so a scalar tensor yields one column.
    num_columns = math.prod(input_dims)
    columns = [f"{name}{i + 1}" for i in range(num_columns)]

    return pd.DataFrame(dtype=dtype, columns=columns)
257281

258282
@property
259283
def algorithm(self) -> str:
260284
return "neural network"
261285

286+
@property
def description(self) -> str:
    """The ``doc_string`` value stored on the wrapped model object."""
    wrapped = self.model
    return wrapped.doc_string
289+
262290
@property
263291
def is_binary_classifier(self) -> bool:
264-
return False
292+
return len(self.output_column_names) == 2
265293

266294
@property
267295
def is_classifier(self) -> bool:
268-
return False
296+
return len(self.output_column_names) > 1
269297

270298
@property
271299
def is_clusterer(self) -> bool:
272300
return False
273301

274302
@property
275303
def is_regressor(self) -> bool:
276-
return False
304+
return len(self.output_column_names) == 1
277305

278306
@property
279307
def model(self) -> object:
280308
return self._model
281309

282310
@property
283311
def model_params(self) -> Dict[str, Any]:
284-
return {}
312+
return {k: getattr(self.model, k, None) for k in ("ir_version", "model_version", "opset_import", "producer_name", "producer_version")}
285313

286314
@property
287315
def predict_function(self) -> Callable:
@@ -300,12 +328,12 @@ def threshold(self) -> Union[str, None]:
300328
return None
301329

302330
@property
303-
def X(self):
304-
return self._X
331+
def X(self) -> pd.DataFrame:
332+
return self._X_df
305333

306334
@property
307-
def y(self):
308-
return self._y
335+
def y(self) -> pd.DataFrame:
336+
return self._y_df
309337

310338

311339
class PyTorchModelInfo(ModelInfo):

tests/unit/test_model_info_onnx.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
import pandas as pd
12
import pytest
23

3-
import sasctl.utils.model_info
4-
54
onnx = pytest.importorskip("onnx")
65
torch = pytest.importorskip("torch")
76

7+
import sasctl.utils.model_info
88
from sasctl.utils import get_model_info
99

1010
# mnist
@@ -41,4 +41,15 @@ def forward(self, x):
4141
def test_get_info(mnist_model):
4242
info = get_model_info(*mnist_model)
4343
assert isinstance(info, sasctl.utils.model_info.OnnxModelInfo)
44-
print(mnist_model)
44+
45+
# Output should be a classification into 10 digits
46+
assert len(info.output_column_names) == 10
47+
assert all(c.startswith("digit") for c in info.output_column_names)
48+
49+
assert isinstance(info.X, pd.DataFrame)
50+
assert len(info.X.columns) == 28 * 28
51+
52+
assert info.is_classifier
53+
assert not info.is_binary_classifier
54+
assert not info.is_regressor
55+
assert not info.is_clusterer

0 commit comments

Comments
 (0)