Skip to content

Commit aaac4c3

Browse files
committed
Update import_model to match write_score_code args
1 parent 9574500 commit aaac4c3

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

src/sasctl/pzmm/import_model.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55
from uuid import UUID
66
from warnings import warn
7-
from typing import Union, Optional, Callable, List, Tuple
7+
from typing import Union, Optional, Callable, List, Tuple, Any
88

99
from pandas import DataFrame
1010

@@ -14,8 +14,6 @@
1414
from .write_score_code import ScoreCode as sc
1515
from .zip_model import ZipModel as zm
1616

17-
# TODO: add converter for any type of dataset (list, dataframe, numpy array)
18-
1917

2018
def project_exists(
2119
project: Union[str, dict, RestObj], response: Union[str, dict, RestObj, None] = None
@@ -138,8 +136,8 @@ def import_model(
138136
model_prefix: str,
139137
project: Union[str, dict, RestObj],
140138
input_data: Optional[DataFrame] = None,
141-
predict_method: Optional[Callable[..., List]] = None,
142-
output_variables: Optional[List[str]] = None,
139+
predict_method: [Callable[..., List], List[Any]] = None,
140+
score_metrics: Optional[List[str]] = None,
143141
pickle_type: str = "pickle",
144142
project_version: str = "latest",
145143
missing_values: bool = False,
@@ -186,17 +184,19 @@ def import_model(
186184
The `DataFrame` object contains the training data, and includes only the
187185
predictor columns. The write_score_code function currently supports int(64),
188186
float(64), and string data types for scoring. The default value is None.
189-
predict_method : function -> list, optional
190-
The Python function used for model predictions. For example, if the model is
191-
a Scikit-Learn DecisionTreeClassifier, then pass either of the following:
192-
sklearn.tree.DecisionTreeClassifier.predict
193-
sklearn.tree.DecisionTreeClassifier.predict_proba
187+
predict_method : [function -> list, list], optional
188+
The Python function used for model predictions and the expected output
189+
types. The expected output types can be passed as example values or as the
190+
value types. For example, if the model is a Scikit-Learn
191+
DecisionTreeClassifier, then pass either of the following:
192+
* [sklearn.tree.DecisionTreeClassifier.predict, ["A"]]
193+
* [sklearn.tree.DecisionTreeClassifier.predict_proba, [0.4, float]]
194194
The default value is None.
195-
output_variables : string list, optional
195+
score_metrics : string list, optional
196196
The scoring score_metrics for the model. For classification models, it is
197197
assumed that the first value in the list represents the classification
198198
output. This function supports single and multi-class classification models.
199-
The default value is None
199+
The default value is None.
200200
pickle_type : string, optional
201201
Indicator for the package used to serialize the model file to be uploaded to
202202
SAS Model Manager. The default value is `pickle`.
@@ -248,7 +248,7 @@ def import_model(
248248
pickle_type = mlflow_details["serialization_format"]
249249

250250
# Import model without generating score code (SAS Viya version invariant)
251-
if input_data is None or not predict_method or not output_variables:
251+
if input_data is None or not predict_method or not score_metrics:
252252
warn(
253253
"The following arguments are required for the automatic generation of "
254254
"score code: input_data, predict_method, score_metrics."
@@ -283,7 +283,7 @@ def import_model(
283283
model_prefix,
284284
input_data,
285285
predict_method,
286-
output_variables,
286+
score_metrics=score_metrics,
287287
pickle_type=pickle_type,
288288
predict_threshold=predict_threshold,
289289
score_code_path=None if isinstance(model_files, dict) else model_files,
@@ -369,7 +369,7 @@ def import_model(
369369
model_prefix,
370370
input_data,
371371
predict_method,
372-
output_variables,
372+
score_metrics=score_metrics,
373373
model=model,
374374
pickle_type=pickle_type,
375375
predict_threshold=predict_threshold,

0 commit comments

Comments
 (0)