15
15
except ImportError :
16
16
torch = None
17
17
18
+ try :
19
+ import onnx
20
+ except ImportError :
21
+ onnx = None
22
+
23
+ try :
24
+ import onnxruntime
25
+ except ImportError :
26
+ onnxruntime = None
27
+
18
28
19
29
def get_model_info (model , X , y = None ):
20
30
"""Extracts metadata about the model and associated data sets.
@@ -40,9 +50,12 @@ def get_model_info(model, X, y=None):
40
50
"""
41
51
42
52
# Don't need to import sklearn, just check if the class is part of that module.
43
- if model . __class__ .__module__ .startswith ("sklearn." ):
53
+ if type ( model ) .__module__ .startswith ("sklearn." ):
44
54
return SklearnModelInfo (model , X , y )
45
55
56
+ if type (model ).__module__ .startswith ("onnx" ):
57
+ return _load_onnx_model (model , X , y )
58
+
46
59
# Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
47
60
# name alone is not sufficient.
48
61
elif torch and isinstance (model , torch .nn .Module ):
@@ -51,17 +64,29 @@ def get_model_info(model, X, y=None):
51
64
raise ValueError (f"Unrecognized model type { type (model )} received." )
52
65
53
66
67
+ def _load_onnx_model (model , X , y = None ):
68
+ # TODO: unncessary? static analysis of onnx file sufficient?
69
+ if onnxruntime :
70
+ return OnnxModelInfo (model , X , y )
71
+
72
+ return OnnxModelInfo (model , X , y )
73
+
74
+
54
75
class ModelInfo (ABC ):
55
76
"""Base class for storing model metadata.
56
77
57
78
Attributes
58
79
----------
59
80
algorithm : str
81
+ Will appear in the "Algorithm" drop-down menu in Model Manager.
82
+ Example: "Forest", "Neural networks", "Binning", etc.
60
83
analytic_function : str
84
+ Will appear in the "Function" drop-down menu in Model Manager.
85
+ Example: "Classification", "Clustering", "Prediction"
61
86
is_binary_classifier : bool
62
- is_classifier
63
- is_regressor
64
- is_clusterer
87
+ is_classifier : bool
88
+ is_regressor : bool
89
+ is_clusterer : bool
65
90
model : object
66
91
The model instance that the information was extracted from.
67
92
model_params : {str: any}
@@ -166,13 +191,130 @@ def y(self) -> pd.DataFrame:
166
191
return
167
192
168
193
194
+ class OnnxModelInfo (ModelInfo ):
195
+ def __init__ (self , model , X , y = None ):
196
+ if onnx is None :
197
+ raise RuntimeError (
198
+ "The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
199
+ )
200
+
201
+ # ONNX serializes models using protobuf, so this should be safe
202
+ from google .protobuf import json_format
203
+
204
+ # TODO: size of X should match size of graph.input
205
+
206
+ self ._model = model
207
+ self ._X = X
208
+ self ._y = y
209
+
210
+ inferred_model = onnx .shape_inference .infer_shapes (model )
211
+
212
+ inputs = [json_format .MessageToDict (i ) for i in inferred_model .graph .input ]
213
+ outputs = [json_format .MessageToDict (o ) for o in inferred_model .graph .output ]
214
+
215
+ if len (inputs ) > 1 :
216
+ pass # TODO: warn that only first input will be captured
217
+
218
+ if len (outputs ) > 1 :
219
+ pass # TODO: warn that only the first output will be captured
220
+
221
+ inputs [0 ]["type" ]["tensorType" ]["elemType" ]
222
+ inputs [0 ]["type" ]["tensorType" ]["shape" ]
223
+
224
+ self ._properties = {
225
+ "description" : model .doc_string ,
226
+ "opset" : model .opset_import
227
+ }
228
+ # initializer (static params)
229
+
230
+ # for field in model.ListFields():
231
+ # doc_string
232
+ # domain
233
+ # metadata_props
234
+ # model_author
235
+ # model_license
236
+ # model_version
237
+ # producer_name
238
+ # producer_version
239
+ # training_info
240
+
241
+ # irVersion
242
+ # producerName
243
+ # producerVersion
244
+ # opsetImport
245
+
246
+
247
+ # # list of (FieldDescriptor, value)
248
+ # fields = model.ListFields()
249
+ # inferred_model = onnx.shape_inference.infer_shapes(model)
250
+ #
251
+ # inputs = model.graph.input
252
+ # assert len(inputs) == 1
253
+ # i = inputs[0]
254
+ # print(i.name)
255
+ # print(i.type)
256
+ # print(i.type.tensor_type.shape)
257
+
258
+ @property
259
+ def algorithm (self ) -> str :
260
+ return "neural network"
261
+
262
+ @property
263
+ def is_binary_classifier (self ) -> bool :
264
+ return False
265
+
266
+ @property
267
+ def is_classifier (self ) -> bool :
268
+ return False
269
+
270
+ @property
271
+ def is_clusterer (self ) -> bool :
272
+ return False
273
+
274
+ @property
275
+ def is_regressor (self ) -> bool :
276
+ return False
277
+
278
+ @property
279
+ def model (self ) -> object :
280
+ return self ._model
281
+
282
+ @property
283
+ def model_params (self ) -> Dict [str , Any ]:
284
+ return {}
285
+
286
+ @property
287
+ def predict_function (self ) -> Callable :
288
+ return None
289
+
290
+ @property
291
+ def target_column (self ):
292
+ return None
293
+
294
+ @property
295
+ def target_values (self ):
296
+ return None
297
+
298
+ @property
299
+ def threshold (self ) -> Union [str , None ]:
300
+ return None
301
+
302
+ @property
303
+ def X (self ):
304
+ return self ._X
305
+
306
+ @property
307
+ def y (self ):
308
+ return self ._y
309
+
310
+
169
311
class PyTorchModelInfo (ModelInfo ):
170
312
"""Stores model information for a PyTorch model instance."""
171
313
172
314
def __init__ (self , model , X , y = None ):
173
315
if torch is None :
174
316
raise RuntimeError (
175
- "The PyTorch library must be installed to work with PyTorch models. Please `pip install torch`."
317
+ "The PyTorch package must be installed to work with PyTorch models. Please `pip install torch`."
176
318
)
177
319
178
320
if not isinstance (model , torch .nn .Module ):
0 commit comments