Skip to content

Commit 51ef8ae

Browse files
committed
chore: black
1 parent 0880489 commit 51ef8ae

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

src/sasctl/utils/model_info.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# SPDX-License-Identifier: Apache-2.0
66

77
import math
8+
import warnings
89
from abc import ABC, abstractmethod
910
from typing import Any, Callable, Dict, List, Union
1011

@@ -202,8 +203,6 @@ def __init__(self, model, X, y=None):
202203
"The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
203204
)
204205

205-
# TODO: size of X should match size of graph.input
206-
207206
self._model = model
208207
self._X = X
209208
self._y = y
@@ -214,10 +213,14 @@ def __init__(self, model, X, y=None):
214213
outputs = [self._tensor_to_dataframe(o) for o in inferred_model.graph.output]
215214

216215
if len(inputs) > 1:
217-
pass # TODO: warn that only first input will be captured
216+
warnings.warn(
217+
f"The ONNX model has {len(inputs)} inputs but only the first input will be captured in Model Manager."
218+
)
218219

219220
if len(outputs) > 1:
220-
pass # TODO: warn that only the first output will be captured
221+
warnings.warn(
222+
f"The ONNX model has {len(outputs)} outputs but only the first input will be captured in Model Manager."
223+
)
221224

222225
self._X_df = inputs[0]
223226
self._y_df = outputs[0]
@@ -263,7 +266,7 @@ def _tensor_to_dataframe(tensor):
263266
"""
264267
if isinstance(tensor, onnx.onnx_ml_pb2.ValueInfoProto):
265268
tensor = json_format.MessageToDict(tensor)
266-
elif not isinstance(tensor, dict):
269+
elif not isinstance(tensor, dict):
267270
raise ValueError(f"Unexpected type {type(tensor)}.")
268271

269272
name = tensor.get("name", "Var")
@@ -275,9 +278,13 @@ def _tensor_to_dataframe(tensor):
275278
dtype = onnx.helper.tensor_dtype_to_np_dtype(type_["tensorType"]["elemType"])
276279

277280
# Tuple of tensor dimensions e.g. (1, 1, 24)
278-
input_dims = tuple(int(d["dimValue"]) for d in type_["tensorType"]["shape"]["dim"])
281+
input_dims = tuple(
282+
int(d["dimValue"]) for d in type_["tensorType"]["shape"]["dim"]
283+
)
279284

280-
return pd.DataFrame(dtype=dtype, columns=[f"{name}{i+1}" for i in range(math.prod(input_dims))])
285+
return pd.DataFrame(
286+
dtype=dtype, columns=[f"{name}{i+1}" for i in range(math.prod(input_dims))]
287+
)
281288

282289
@property
283290
def algorithm(self) -> str:
@@ -309,7 +316,16 @@ def model(self) -> object:
309316

310317
@property
311318
def model_params(self) -> Dict[str, Any]:
312-
return {k: getattr(self.model, k, None) for k in ("ir_version", "model_version", "opset_import", "producer_name", "producer_version")}
319+
return {
320+
k: getattr(self.model, k, None)
321+
for k in (
322+
"ir_version",
323+
"model_version",
324+
"opset_import",
325+
"producer_name",
326+
"producer_version",
327+
)
328+
}
313329

314330
@property
315331
def predict_function(self) -> Callable:

tests/unit/test_model_info_onnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# classification/regression/etc
1414
#
1515

16+
1617
@pytest.fixture
1718
def mnist_model(tmp_path):
1819
class Net(torch.nn.Module):
@@ -52,4 +53,4 @@ def test_get_info(mnist_model):
5253
assert info.is_classifier
5354
assert not info.is_binary_classifier
5455
assert not info.is_regressor
55-
assert not info.is_clusterer
56+
assert not info.is_clusterer

0 commit comments

Comments
 (0)