|
4 | 4 | from pathlib import Path
|
5 | 5 | from uuid import UUID
|
6 | 6 | from warnings import warn
|
7 |
| -from typing import Union, Optional, Callable, List, Tuple |
| 7 | +from typing import Union, Optional, Callable, List, Tuple, Any |
8 | 8 |
|
9 | 9 | from pandas import DataFrame
|
10 | 10 |
|
|
14 | 14 | from .write_score_code import ScoreCode as sc
|
15 | 15 | from .zip_model import ZipModel as zm
|
16 | 16 |
|
17 |
| -# TODO: add converter for any type of dataset (list, dataframe, numpy array) |
18 |
| - |
19 | 17 |
|
20 | 18 | def project_exists(
|
21 | 19 | project: Union[str, dict, RestObj], response: Union[str, dict, RestObj, None] = None
|
@@ -138,8 +136,8 @@ def import_model(
|
138 | 136 | model_prefix: str,
|
139 | 137 | project: Union[str, dict, RestObj],
|
140 | 138 | input_data: Optional[DataFrame] = None,
|
141 |
| - predict_method: Optional[Callable[..., List]] = None, |
142 |
| - output_variables: Optional[List[str]] = None, |
| 139 | + predict_method: [Callable[..., List], List[Any]] = None, |
| 140 | + score_metrics: Optional[List[str]] = None, |
143 | 141 | pickle_type: str = "pickle",
|
144 | 142 | project_version: str = "latest",
|
145 | 143 | missing_values: bool = False,
|
@@ -186,17 +184,19 @@ def import_model(
|
186 | 184 | The `DataFrame` object contains the training data, and includes only the
|
187 | 185 | predictor columns. The write_score_code function currently supports int(64),
|
188 | 186 | float(64), and string data types for scoring. The default value is None.
|
189 |
| - predict_method : function -> list, optional |
190 |
| - The Python function used for model predictions. For example, if the model is |
191 |
| - a Scikit-Learn DecisionTreeClassifier, then pass either of the following: |
192 |
| - sklearn.tree.DecisionTreeClassifier.predict |
193 |
| - sklearn.tree.DecisionTreeClassifier.predict_proba |
| 187 | + predict_method : [function -> list, list], optional |
| 188 | + The Python function used for model predictions and the expected output |
| 189 | + types. The expected output types can be passed as example values or as the |
| 190 | + value types. For example, if the model is a Scikit-Learn |
| 191 | + DecisionTreeClassifier, then pass either of the following: |
| 192 | + * [sklearn.tree.DecisionTreeClassifier.predict, ["A"]] |
| 193 | + * [sklearn.tree.DecisionTreeClassifier.predict_proba, [0.4, float]] |
194 | 194 | The default value is None.
|
195 |
| - output_variables : string list, optional |
| 195 | + score_metrics : string list, optional |
196 | 196 | The scoring score_metrics for the model. For classification models, it is
|
197 | 197 | assumed that the first value in the list represents the classification
|
198 | 198 | output. This function supports single and multi-class classification models.
|
199 |
| - The default value is None |
| 199 | + The default value is None. |
200 | 200 | pickle_type : string, optional
|
201 | 201 | Indicator for the package used to serialize the model file to be uploaded to
|
202 | 202 | SAS Model Manager. The default value is `pickle`.
|
@@ -248,7 +248,7 @@ def import_model(
|
248 | 248 | pickle_type = mlflow_details["serialization_format"]
|
249 | 249 |
|
250 | 250 | # Import model without generating score code (SAS Viya version invariant)
|
251 |
| - if input_data is None or not predict_method or not output_variables: |
| 251 | + if input_data is None or not predict_method or not score_metrics: |
252 | 252 | warn(
|
253 | 253 | "The following arguments are required for the automatic generation of "
|
254 | 254 | "score code: input_data, predict_method, score_metrics."
|
@@ -283,7 +283,7 @@ def import_model(
|
283 | 283 | model_prefix,
|
284 | 284 | input_data,
|
285 | 285 | predict_method,
|
286 |
| - output_variables, |
| 286 | + score_metrics=score_metrics, |
287 | 287 | pickle_type=pickle_type,
|
288 | 288 | predict_threshold=predict_threshold,
|
289 | 289 | score_code_path=None if isinstance(model_files, dict) else model_files,
|
@@ -369,7 +369,7 @@ def import_model(
|
369 | 369 | model_prefix,
|
370 | 370 | input_data,
|
371 | 371 | predict_method,
|
372 |
| - output_variables, |
| 372 | + score_metrics=score_metrics, |
373 | 373 | model=model,
|
374 | 374 | pickle_type=pickle_type,
|
375 | 375 | predict_threshold=predict_threshold,
|
|
0 commit comments