Skip to content

Commit 06939a9

Browse files
committed
feat: initial onnx model info
1 parent ae0639e commit 06939a9

File tree

2 files changed

+191
-5
lines changed

2 files changed

+191
-5
lines changed

src/sasctl/utils/model_info.py

Lines changed: 147 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@
1515
except ImportError:
1616
torch = None
1717

18+
try:
19+
import onnx
20+
except ImportError:
21+
onnx = None
22+
23+
try:
24+
import onnxruntime
25+
except ImportError:
26+
onnxruntime = None
27+
1828

1929
def get_model_info(model, X, y=None):
2030
"""Extracts metadata about the model and associated data sets.
@@ -40,9 +50,12 @@ def get_model_info(model, X, y=None):
4050
"""
4151

4252
# Don't need to import sklearn, just check if the class is part of that module.
43-
if model.__class__.__module__.startswith("sklearn."):
53+
if type(model).__module__.startswith("sklearn."):
4454
return SklearnModelInfo(model, X, y)
4555

56+
if type(model).__module__.startswith("onnx"):
57+
return _load_onnx_model(model, X, y)
58+
4659
# Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
4760
# name alone is not sufficient.
4861
elif torch and isinstance(model, torch.nn.Module):
@@ -51,17 +64,29 @@ def get_model_info(model, X, y=None):
5164
raise ValueError(f"Unrecognized model type {type(model)} received.")
5265

5366

67+
def _load_onnx_model(model, X, y=None):
68+
# TODO: unncessary? static analysis of onnx file sufficient?
69+
if onnxruntime:
70+
return OnnxModelInfo(model, X, y)
71+
72+
return OnnxModelInfo(model, X, y)
73+
74+
5475
class ModelInfo(ABC):
5576
"""Base class for storing model metadata.
5677
5778
Attributes
5879
----------
5980
algorithm : str
81+
Will appear in the "Algorithm" drop-down menu in Model Manager.
82+
Example: "Forest", "Neural networks", "Binning", etc.
6083
analytic_function : str
84+
Will appear in the "Function" drop-down menu in Model Manager.
85+
Example: "Classification", "Clustering", "Prediction"
6186
is_binary_classifier : bool
62-
is_classifier
63-
is_regressor
64-
is_clusterer
87+
is_classifier : bool
88+
is_regressor : bool
89+
is_clusterer : bool
6590
model : object
6691
The model instance that the information was extracted from.
6792
model_params : {str: any}
@@ -166,13 +191,130 @@ def y(self) -> pd.DataFrame:
166191
return
167192

168193

194+
class OnnxModelInfo(ModelInfo):
195+
def __init__(self, model, X, y=None):
196+
if onnx is None:
197+
raise RuntimeError(
198+
"The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
199+
)
200+
201+
# ONNX serializes models using protobuf, so this should be safe
202+
from google.protobuf import json_format
203+
204+
# TODO: size of X should match size of graph.input
205+
206+
self._model = model
207+
self._X = X
208+
self._y = y
209+
210+
inferred_model = onnx.shape_inference.infer_shapes(model)
211+
212+
inputs = [json_format.MessageToDict(i) for i in inferred_model.graph.input]
213+
outputs = [json_format.MessageToDict(o) for o in inferred_model.graph.output]
214+
215+
if len(inputs) > 1:
216+
pass # TODO: warn that only first input will be captured
217+
218+
if len(outputs) > 1:
219+
pass # TODO: warn that only the first output will be captured
220+
221+
inputs[0]["type"]["tensorType"]["elemType"]
222+
inputs[0]["type"]["tensorType"]["shape"]
223+
224+
self._properties = {
225+
"description": model.doc_string,
226+
"opset": model.opset_import
227+
}
228+
# initializer (static params)
229+
230+
# for field in model.ListFields():
231+
# doc_string
232+
# domain
233+
# metadata_props
234+
# model_author
235+
# model_license
236+
# model_version
237+
# producer_name
238+
# producer_version
239+
# training_info
240+
241+
# irVersion
242+
# producerName
243+
# producerVersion
244+
# opsetImport
245+
246+
247+
# # list of (FieldDescriptor, value)
248+
# fields = model.ListFields()
249+
# inferred_model = onnx.shape_inference.infer_shapes(model)
250+
#
251+
# inputs = model.graph.input
252+
# assert len(inputs) == 1
253+
# i = inputs[0]
254+
# print(i.name)
255+
# print(i.type)
256+
# print(i.type.tensor_type.shape)
257+
258+
@property
259+
def algorithm(self) -> str:
260+
return "neural network"
261+
262+
@property
263+
def is_binary_classifier(self) -> bool:
264+
return False
265+
266+
@property
267+
def is_classifier(self) -> bool:
268+
return False
269+
270+
@property
271+
def is_clusterer(self) -> bool:
272+
return False
273+
274+
@property
275+
def is_regressor(self) -> bool:
276+
return False
277+
278+
@property
279+
def model(self) -> object:
280+
return self._model
281+
282+
@property
283+
def model_params(self) -> Dict[str, Any]:
284+
return {}
285+
286+
@property
287+
def predict_function(self) -> Callable:
288+
return None
289+
290+
@property
291+
def target_column(self):
292+
return None
293+
294+
@property
295+
def target_values(self):
296+
return None
297+
298+
@property
299+
def threshold(self) -> Union[str, None]:
300+
return None
301+
302+
@property
303+
def X(self):
304+
return self._X
305+
306+
@property
307+
def y(self):
308+
return self._y
309+
310+
169311
class PyTorchModelInfo(ModelInfo):
170312
"""Stores model information for a PyTorch model instance."""
171313

172314
def __init__(self, model, X, y=None):
173315
if torch is None:
174316
raise RuntimeError(
175-
"The PyTorch library must be installed to work with PyTorch models. Please `pip install torch`."
317+
"The PyTorch package must be installed to work with PyTorch models. Please `pip install torch`."
176318
)
177319

178320
if not isinstance(model, torch.nn.Module):

tests/unit/test_model_info_onnx.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pytest
2+
3+
import sasctl.utils.model_info
4+
5+
onnx = pytest.importorskip("onnx")
6+
torch = pytest.importorskip("torch")
7+
8+
from sasctl.utils import get_model_info
9+
10+
# mnist
11+
# get input/output shapes
12+
# get var names if available
13+
# classification/regression/etc
14+
#
15+
16+
@pytest.fixture
17+
def mnist_model(tmp_path):
18+
class Net(torch.nn.Module):
19+
def __init__(self):
20+
super(Net, self).__init__()
21+
self.fc1 = torch.nn.Linear(14 * 14, 128)
22+
self.fc2 = torch.nn.Linear(128, 10)
23+
24+
def forward(self, x):
25+
x = torch.nn.functional.max_pool2d(x, 2)
26+
x = x.reshape(-1, 1 * 14 * 14)
27+
x = self.fc1(x)
28+
x = torch.nn.functional.relu(x)
29+
x = self.fc2(x)
30+
output = torch.nn.functional.softmax(x, dim=1)
31+
return output
32+
33+
model = Net()
34+
35+
path = tmp_path / "model.onnx"
36+
X = torch.randn(1, 1, 28, 28)
37+
torch.onnx.export(model, X, path, input_names=["image"], output_names=["digit"])
38+
yield onnx.load(path), X
39+
40+
41+
def test_get_info(mnist_model):
42+
info = get_model_info(*mnist_model)
43+
assert isinstance(info, sasctl.utils.model_info.OnnxModelInfo)
44+
print(mnist_model)

0 commit comments

Comments
 (0)