4
4
# Copyright © 2023, SAS Institute Inc., Cary, NC, USA. All Rights Reserved.
5
5
# SPDX-License-Identifier: Apache-2.0
6
6
7
+ import math
7
8
from abc import ABC , abstractmethod
8
9
from typing import Any , Callable , Dict , List , Union
9
10
17
18
18
19
try :
19
20
import onnx
21
+
22
+ # ONNX serializes models using protobuf, so this should be safe
23
+ from google .protobuf import json_format
20
24
except ImportError :
21
25
onnx = None
22
26
@@ -198,9 +202,6 @@ def __init__(self, model, X, y=None):
198
202
"The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
199
203
)
200
204
201
- # ONNX serializes models using protobuf, so this should be safe
202
- from google .protobuf import json_format
203
-
204
205
# TODO: size of X should match size of graph.input
205
206
206
207
self ._model = model
@@ -209,22 +210,18 @@ def __init__(self, model, X, y=None):
209
210
210
211
inferred_model = onnx .shape_inference .infer_shapes (model )
211
212
212
- inputs = [json_format . MessageToDict (i ) for i in inferred_model .graph .input ]
213
- outputs = [json_format . MessageToDict (o ) for o in inferred_model .graph .output ]
213
+ inputs = [self . _tensor_to_dataframe (i ) for i in inferred_model .graph .input ]
214
+ outputs = [self . _tensor_to_dataframe (o ) for o in inferred_model .graph .output ]
214
215
215
216
if len (inputs ) > 1 :
216
217
pass # TODO: warn that only first input will be captured
217
218
218
219
if len (outputs ) > 1 :
219
220
pass # TODO: warn that only the first output will be captured
220
221
221
- inputs [0 ][ "type" ][ "tensorType" ][ "elemType" ]
222
- inputs [ 0 ][ "type" ][ "tensorType" ][ "shape" ]
222
+ self . _X_df = inputs [0 ]
223
+ self . _y_df = outputs [ 0 ]
223
224
224
- self ._properties = {
225
- "description" : model .doc_string ,
226
- "opset" : model .opset_import
227
- }
228
225
# initializer (static params)
229
226
230
227
# for field in model.ListFields():
@@ -243,45 +240,76 @@ def __init__(self, model, X, y=None):
243
240
# producerVersion
244
241
# opsetImport
245
242
246
-
247
243
# # list of (FieldDescriptor, value)
248
244
# fields = model.ListFields()
249
- # inferred_model = onnx.shape_inference.infer_shapes(model)
250
- #
251
- # inputs = model.graph.input
252
- # assert len(inputs) == 1
253
- # i = inputs[0]
254
- # print(i.name)
255
- # print(i.type)
256
- # print(i.type.tensor_type.shape)
245
+
246
+ @staticmethod
247
+ def _tensor_to_dataframe (tensor ):
248
+ """
249
+
250
+ Parameters
251
+ ----------
252
+ tensor : onnx.onnx_ml_pb2.ValueInfoProto or dict
253
+ A protobuf `Message` containing information
254
+
255
+ Returns
256
+ -------
257
+ pandas.DataFrame
258
+
259
+ Examples
260
+ --------
261
+ df = _tensor_to_dataframe(model.graph.input[0])
262
+
263
+ """
264
+ if isinstance (tensor , onnx .onnx_ml_pb2 .ValueInfoProto ):
265
+ tensor = json_format .MessageToDict (tensor )
266
+ elif not isinstance (tensor , dict ):
267
+ raise ValueError (f"Unexpected type { type (tensor )} ." )
268
+
269
+ name = tensor .get ("name" , "Var" )
270
+ type_ = tensor ["type" ]
271
+
272
+ if not "tensorType" in type_ :
273
+ raise ValueError (f"Received an unexpected ONNX input type: { type_ } ." )
274
+
275
+ dtype = onnx .helper .tensor_dtype_to_np_dtype (type_ ["tensorType" ]["elemType" ])
276
+
277
+ # Tuple of tensor dimensions e.g. (1, 1, 24)
278
+ input_dims = tuple (int (d ["dimValue" ]) for d in type_ ["tensorType" ]["shape" ]["dim" ])
279
+
280
+ return pd .DataFrame (dtype = dtype , columns = [f"{ name } { i + 1 } " for i in range (math .prod (input_dims ))])
257
281
258
282
@property
259
283
def algorithm (self ) -> str :
260
284
return "neural network"
261
285
286
+ @property
287
+ def description (self ) -> str :
288
+ return self .model .doc_string
289
+
262
290
@property
263
291
def is_binary_classifier (self ) -> bool :
264
- return False
292
+ return len ( self . output_column_names ) == 2
265
293
266
294
@property
267
295
def is_classifier (self ) -> bool :
268
- return False
296
+ return len ( self . output_column_names ) > 1
269
297
270
298
@property
271
299
def is_clusterer (self ) -> bool :
272
300
return False
273
301
274
302
@property
275
303
def is_regressor (self ) -> bool :
276
- return False
304
+ return len ( self . output_column_names ) == 1
277
305
278
306
@property
279
307
def model (self ) -> object :
280
308
return self ._model
281
309
282
310
@property
283
311
def model_params (self ) -> Dict [str , Any ]:
284
- return {}
312
+ return {k : getattr ( self . model , k , None ) for k in ( "ir_version" , "model_version" , "opset_import" , "producer_name" , "producer_version" ) }
285
313
286
314
@property
287
315
def predict_function (self ) -> Callable :
@@ -300,12 +328,12 @@ def threshold(self) -> Union[str, None]:
300
328
return None
301
329
302
330
@property
303
- def X (self ):
304
- return self ._X
331
+ def X (self ) -> pd . DataFrame :
332
+ return self ._X_df
305
333
306
334
@property
307
- def y (self ):
308
- return self ._y
335
+ def y (self ) -> pd . DataFrame :
336
+ return self ._y_df
309
337
310
338
311
339
class PyTorchModelInfo (ModelInfo ):
0 commit comments