5
5
# SPDX-License-Identifier: Apache-2.0
6
6
7
7
import math
8
+ import warnings
8
9
from abc import ABC , abstractmethod
9
10
from typing import Any , Callable , Dict , List , Union
10
11
@@ -202,8 +203,6 @@ def __init__(self, model, X, y=None):
202
203
"The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
203
204
)
204
205
205
- # TODO: size of X should match size of graph.input
206
-
207
206
self ._model = model
208
207
self ._X = X
209
208
self ._y = y
@@ -214,10 +213,14 @@ def __init__(self, model, X, y=None):
214
213
outputs = [self ._tensor_to_dataframe (o ) for o in inferred_model .graph .output ]
215
214
216
215
if len (inputs ) > 1 :
217
- pass # TODO: warn that only first input will be captured
216
+ warnings .warn (
217
+ f"The ONNX model has { len (inputs )} inputs but only the first input will be captured in Model Manager."
218
+ )
218
219
219
220
if len (outputs ) > 1 :
220
- pass # TODO: warn that only the first output will be captured
221
+ warnings .warn (
222
+ f"The ONNX model has { len (outputs )} outputs but only the first input will be captured in Model Manager."
223
+ )
221
224
222
225
self ._X_df = inputs [0 ]
223
226
self ._y_df = outputs [0 ]
@@ -263,7 +266,7 @@ def _tensor_to_dataframe(tensor):
263
266
"""
264
267
if isinstance (tensor , onnx .onnx_ml_pb2 .ValueInfoProto ):
265
268
tensor = json_format .MessageToDict (tensor )
266
- elif not isinstance (tensor , dict ):
269
+ elif not isinstance (tensor , dict ):
267
270
raise ValueError (f"Unexpected type { type (tensor )} ." )
268
271
269
272
name = tensor .get ("name" , "Var" )
@@ -275,9 +278,13 @@ def _tensor_to_dataframe(tensor):
275
278
dtype = onnx .helper .tensor_dtype_to_np_dtype (type_ ["tensorType" ]["elemType" ])
276
279
277
280
# Tuple of tensor dimensions e.g. (1, 1, 24)
278
- input_dims = tuple (int (d ["dimValue" ]) for d in type_ ["tensorType" ]["shape" ]["dim" ])
281
+ input_dims = tuple (
282
+ int (d ["dimValue" ]) for d in type_ ["tensorType" ]["shape" ]["dim" ]
283
+ )
279
284
280
- return pd .DataFrame (dtype = dtype , columns = [f"{ name } { i + 1 } " for i in range (math .prod (input_dims ))])
285
+ return pd .DataFrame (
286
+ dtype = dtype , columns = [f"{ name } { i + 1 } " for i in range (math .prod (input_dims ))]
287
+ )
281
288
282
289
@property
283
290
def algorithm (self ) -> str :
@@ -309,7 +316,16 @@ def model(self) -> object:
309
316
310
317
@property
311
318
def model_params (self ) -> Dict [str , Any ]:
312
- return {k : getattr (self .model , k , None ) for k in ("ir_version" , "model_version" , "opset_import" , "producer_name" , "producer_version" )}
319
+ return {
320
+ k : getattr (self .model , k , None )
321
+ for k in (
322
+ "ir_version" ,
323
+ "model_version" ,
324
+ "opset_import" ,
325
+ "producer_name" ,
326
+ "producer_version" ,
327
+ )
328
+ }
313
329
314
330
@property
315
331
def predict_function (self ) -> Callable :
0 commit comments