Skip to content

Commit 80d4f0c

Browse files
committed
use model_info for sklearn
1 parent 993c7a5 commit 80d4f0c

File tree

5 files changed

+232
-106
lines changed

5 files changed

+232
-106
lines changed

src/sasctl/tasks.py

Lines changed: 58 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414
import re
1515
import sys
1616
import warnings
17-
from tempfile import TemporaryDirectory
18-
19-
import pandas as pd
2017

2118
try:
2219
import swat
2320
except ImportError:
2421
swat = None
2522

23+
import pandas as pd
2624
from urllib.error import HTTPError
2725

2826
from . import pzmm, utils
@@ -33,7 +31,7 @@
3331
from .services import model_repository as mr
3432
from .utils.pymas import from_pickle
3533
from .utils.misc import installed_packages
36-
34+
from .utils.model_info import get_model_info
3735

3836
logger = logging.getLogger(__name__)
3937

@@ -47,10 +45,13 @@
4745
_PROP_NAME_MAXLEN = 60
4846

4947

48+
49+
5050
def _property(k, v):
5151
return {"name": str(k)[:_PROP_NAME_MAXLEN], "value": str(v)[:_PROP_VALUE_MAXLEN]}
5252

5353

54+
5455
def _sklearn_to_dict(model):
5556
# Convert Scikit-learn values to built-in Model Manager values
5657
mappings = {
@@ -68,6 +69,7 @@ def _sklearn_to_dict(model):
6869
"regressor": "prediction",
6970
}
7071

72+
# If this is a Pipeline extract the final estimator step
7173
if hasattr(model, "_final_estimator"):
7274
estimator = model._final_estimator
7375
else:
@@ -97,78 +99,67 @@ def _sklearn_to_dict(model):
9799
trainCodeType="Python",
98100
targetLevel=target_level,
99101
function=analytic_function,
100-
tool="Python %s.%s" % (sys.version_info.major, sys.version_info.minor),
102+
tool=f"Python {sys.version_info.major}.{sys.version_info.minor}",
101103
properties=[_property(k, v) for k, v in model.get_params().items()],
102104
)
103105

104106
return result
105107

106108

107-
def _register_sklearn_35():
108-
pass
109-
110-
111-
def _register_sklearn_40(model, model_name, project_name, input_data, output_data=None):
112-
113-
# TODO: if not sklearn, raise ValueError
109+
def _register_sklearn_40(model, model_name, project_name, input_data, output_data, overwrite=False):
110+
model_info = get_model_info(model, input_data, output_data)
114111

115-
model_info = _sklearn_to_dict(model)
112+
# TODO: allow passing description in register_model()
116113

117-
with TemporaryDirectory() as folder:
114+
# Will store filename: file contents as we generate files
115+
files = {}
118116

119-
# Write model to a pickle file
120-
pzmm.PickleModel.pickle_trained_model(model, model_name, folder) # generates folder/name.pickle
117+
# Write model to a pickle file
118+
files.update(pzmm.PickleModel.pickle_trained_model(model, model_name))
121119

122-
# Create a JSON file containing model input fields
123-
pzmm.JSONFiles.write_var_json(input_data, is_input=True, json_path=folder)
120+
# Create JSON files containing the model input and output fields
121+
files.update(pzmm.JSONFiles.write_var_json(input_data))
122+
files.update(pzmm.JSONFiles.write_var_json(output_data, is_input=False))
124123

125-
# Create a JSON file containing model output fields
126-
if output_data is not None:
127-
if model_info["function"] == "classification":
128-
output_fields = output_data.copy()
124+
if model_info.is_binary_classifier:
125+
num_categories = 2
126+
elif model_info.is_classifier:
127+
num_categories = len(model_info.target_values)
128+
else:
129+
num_categories = 0
130+
131+
files.update(pzmm.JSONFiles.write_model_properties_json(model_name,
132+
target_variable=model_info.output_column_names,
133+
target_event=model_info.target_values,
134+
num_target_categories=num_categories,
135+
event_prob_var=None,
136+
model_desc=model_info.description[:_DESC_MAXLEN],
137+
model_function=model_info.analytic_function,
138+
model_type=model_info.algorithm
139+
))
140+
"""
141+
target_variable : string
142+
Target variable to be predicted by the model.
143+
target_event : string
144+
Model target event. For example: 1 for a binary event.
145+
num_target_categories : int
146+
Number of possible target categories. For example: 2 for a binary event.
147+
event_prob_var : string, optional
148+
User-provided output event probability variable. This value should match the
149+
value in outputVar.json. Default is "P_" + target_variable + target_event.
150+
"""
151+
files.update(pzmm.JSONFiles.write_file_metadata_json(model_name))
129152

130-
if hasattr(output_fields, "columns"):
131-
output_fields.columns = ["EM_CLASSIFICATION"]
132-
else:
133-
output_fields.name = "EM_CLASSIFICATION"
134-
pzmm.JSONFiles.write_var_json(output_fields, is_input=False, json_path=folder)
135-
else:
136-
pzmm.JSONFiles.write_var_json(output_data, is_input=False, json_path=folder)
137-
# target_variable
138-
# target_event (e.g 1 for binary)
139-
# num_target_event
140-
# event_prob
141-
142-
# TODO: allow passing description in register_model()
143-
144-
pzmm.JSONFiles.write_model_properties_json(model_name,
145-
target_event=None,
146-
target_variable=None,
147-
num_target_categories=1,
148-
model_desc=model_info["description"],
149-
model_function=model_info["function"],
150-
model_type=model_info["algorithm"],
151-
json_path=folder
152-
)
153-
154-
pzmm.JSONFiles.write_file_metadata_json(model_name, json_path=folder, is_h2o_model=False)
155-
156-
predict_method = (
157-
"{}.predict_proba({})"
158-
if hasattr(model, "predict_proba")
159-
else "{}.predict({})"
160-
)
161-
predict_method = "{}.predict({})"
162-
metrics = ["EM_CLASSIFICATION"] # NOTE: only valid for classification models.
163-
pzmm.ImportModel.import_model(
164-
folder,
165-
model_name,
166-
project_name,
167-
input_data,
168-
output_data,
169-
predict_method,
170-
metrics=metrics,
171-
)
153+
# TODO: How to determine whether to call .predict() or .predict_proba()? Base it on the output data?
154+
pzmm.ImportModel.import_model(model_files=files,
155+
model_prefix=model_name,
156+
project=project_name,
157+
predict_method=model.predict,
158+
input_data=input_data,
159+
output_variables=[],
160+
score_cas=True,
161+
missing_values=False # assuming Pipeline will be used for imputing.
162+
)
172163

173164

174165
def _create_project(project_name, model, repo, input_vars=None, output_vars=None):
@@ -275,6 +266,8 @@ def register_model(
275266
information. If a single type is provided, all columns will be assumed
276267
to be that type, otherwise a list of column types or a dictionary of
277268
column_name: type may be provided.
269+
output : array-like
270+
A Numpy array or Pandas DataFrame that contains example output data for the model.
278271
version : {'new', 'latest', int}, optional
279272
Version number of the project in which the model should be created.
280273
Defaults to 'new'.
@@ -315,8 +308,6 @@ def register_model(
315308
Update ASTORE handling for ease of use and removal of SAS Viya 4 score code errors
316309
317310
"""
318-
# TODO: Create new version if model already exists
319-
320311
# If version not specified, default to creating a new version
321312
version = version or "new"
322313

@@ -458,6 +449,7 @@ def register_model(
458449
# If the model is a scikit-learn model, generate the model dictionary
459450
# from it and pickle the model for storage
460451
if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
452+
461453
# Pickle the model so we can store it
462454
model_pkl = pickle.dumps(model)
463455
files.append({"name": "model.pkl", "file": model_pkl, "role": "Python Pickle"})

src/sasctl/utils/model_info.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class ModelInfo(ABC):
6565
observation belongs to. Returns None if not a binary classification model.
6666
6767
"""
68+
6869
@property
6970
@abstractmethod
7071
def algorithm(self) -> str:
@@ -124,6 +125,8 @@ def predict_function(self) -> Callable:
124125
@property
125126
@abstractmethod
126127
def target_values(self):
128+
# "target event"
129+
# value that indicates the target event has occurred in binary classification
127130
return
128131

129132
@property
@@ -134,6 +137,7 @@ def threshold(self) -> Union[str, None]:
134137

135138
class SklearnModelInfo(ModelInfo):
136139
"""Stores model information for a scikit-learn model instance."""
140+
137141
# Map class names from sklearn to algorithm names used by SAS
138142
_algorithm_mappings = {
139143
"LogisticRegression": "Logistic regression",
@@ -145,7 +149,7 @@ class SklearnModelInfo(ModelInfo):
145149
"RandomForestClassifier": "Forest",
146150
"RandomForestRegressor": "Forest",
147151
"DecisionTreeClassifier": "Decision tree",
148-
"DecisionTreeRegressor": "Decision tree"
152+
"DecisionTreeRegressor": "Decision tree",
149153
}
150154

151155
def __init__(self, model, X, y):
@@ -170,7 +174,7 @@ def __init__(self, model, X, y):
170174
self._is_clusterer = is_clusterer
171175
self._model = model
172176

173-
if not hasattr(y, "columns"):
177+
if not hasattr(y, "name") and not hasattr(y, "columns"):
174178
# If example output doesn't contain column names then our DataFrame equivalent
175179
# also lacks good column names. Assign reasonable names for use downstream.
176180
if y_df.shape[1] == 1:

tests/conftest.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -519,27 +519,14 @@ def iris_dataset():
519519

520520

521521
@pytest.fixture
522-
def sklearn_iris_data():
523-
"""Returns the Iris data set as (X, y)"""
524-
pd = pytest.importorskip("pandas")
525-
sk = pytest.importorskip("sklearn.datasets")
526-
527-
raw = sk.load_iris()
528-
iris = pd.DataFrame(raw.data, columns=raw.feature_names)
529-
iris = iris.join(pd.DataFrame(raw.target))
530-
iris.columns = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "Species"]
531-
iris["Species"] = iris["Species"].astype("category")
532-
iris.Species.cat.categories = raw.target_names
533-
return iris.iloc[:, 0:4], iris["Species"]
534-
535-
536-
@pytest.fixture
537-
def sklearn_classification_model(sklearn_iris_data):
522+
def sklearn_classification_model(iris_dataset):
538523
"""Returns a simple scikit-learn model"""
539524
sk = pytest.importorskip("sklearn.linear_model")
540525
import warnings
541526

542-
X, y = sklearn_iris_data
527+
X = iris_dataset.drop(columns="Species")
528+
y = iris_dataset["Species"]
529+
543530
with warnings.catch_warnings():
544531
warnings.simplefilter("ignore")
545532
model = sk.LogisticRegression(multi_class="multinomial", solver="lbfgs")

0 commit comments

Comments
 (0)