Skip to content

Commit 303b371

Browse files
committed
feat: re-added model info
1 parent d42cf8e commit 303b371

File tree

2 files changed

+423
-0
lines changed

2 files changed

+423
-0
lines changed

src/sasctl/utils/model_info.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
#
4+
# Copyright © 2023, SAS Institute Inc., Cary, NC, USA. All Rights Reserved.
5+
# SPDX-License-Identifier: Apache-2.0
6+
7+
from abc import ABC, abstractmethod
8+
from typing import Any, Callable, Dict, List, Union
9+
10+
import pandas as pd
11+
12+
13+
def get_model_info(model, X, y):
14+
"""Extracts metadata about the model and associated data sets.
15+
16+
Parameters
17+
----------
18+
model : object
19+
A trained model
20+
X : array-like
21+
Sample of the data used to train the model.
22+
y : array-like
23+
Sample of the output produced by the model.
24+
25+
Returns
26+
-------
27+
ModelInfo
28+
29+
Raises
30+
------
31+
ValueError
32+
If `model` is not a recognized type.
33+
34+
"""
35+
if model.__class__.__module__.startswith("sklearn."):
36+
return SklearnModelInfo(model, X, y)
37+
38+
raise ValueError(f"Unrecognized model type {model} received.")
39+
40+
41+
class ModelInfo(ABC):
42+
"""Base class for storing model metadata.
43+
44+
Attributes
45+
----------
46+
algorithm : str
47+
analytic_function : str
48+
is_binary_classifier : bool
49+
is_classifier
50+
is_regressor
51+
is_clusterer
52+
model : object
53+
The model instance that the information was extracted from.
54+
model_params : {str: any}
55+
Dictionary of parameter names and values.
56+
output_column_names : list of str
57+
Variable names associated with the outputs of `model`.
58+
predict_function : callable
59+
The method on `model` that is called to produce predicted values.
60+
target_values : list of str
61+
Class labels returned by a classification model. For binary classification models
62+
this is just the label of the targeted event level.
63+
threshold : float or None
64+
The cutoff value used in a binary classification model to determine which class an
65+
observation belongs to. Returns None if not a binary classification model.
66+
67+
"""
68+
69+
@property
70+
@abstractmethod
71+
def algorithm(self) -> str:
72+
return
73+
74+
@property
75+
def analytic_function(self) -> str:
76+
if self.is_classifier:
77+
return "classification"
78+
if self.is_regressor:
79+
return "prediction"
80+
81+
@property
82+
def description(self) -> str:
83+
return str(self.model)
84+
85+
@property
86+
@abstractmethod
87+
def is_binary_classifier(self) -> bool:
88+
return
89+
90+
@property
91+
@abstractmethod
92+
def is_classifier(self) -> bool:
93+
return
94+
95+
@property
96+
@abstractmethod
97+
def is_clusterer(self) -> bool:
98+
return
99+
100+
@property
101+
@abstractmethod
102+
def is_regressor(self) -> bool:
103+
return
104+
105+
@property
106+
@abstractmethod
107+
def model(self) -> object:
108+
return
109+
110+
@property
111+
@abstractmethod
112+
def model_params(self) -> Dict[str, Any]:
113+
return
114+
115+
@property
116+
@abstractmethod
117+
def output_column_names(self) -> List[str]:
118+
return
119+
120+
@property
121+
@abstractmethod
122+
def predict_function(self) -> Callable:
123+
return
124+
125+
@property
126+
@abstractmethod
127+
def target_values(self):
128+
# "target event"
129+
# value that indicates the target event has occurred in bianry classi
130+
return
131+
132+
@property
133+
@abstractmethod
134+
def threshold(self) -> Union[str, None]:
135+
return
136+
137+
138+
class SklearnModelInfo(ModelInfo):
139+
"""Stores model information for a scikit-learn model instance."""
140+
141+
# Map class names from sklearn to algorithm names used by SAS
142+
_algorithm_mappings = {
143+
"LogisticRegression": "Logistic regression",
144+
"LinearRegression": "Linear regression",
145+
"SVC": "Support vector machine",
146+
"SVR": "Support vector machine",
147+
"GradientBoostingClassifier": "Gradient boosting",
148+
"GradientBoostingRegressor": "Gradient boosting",
149+
"RandomForestClassifier": "Forest",
150+
"RandomForestRegressor": "Forest",
151+
"DecisionTreeClassifier": "Decision tree",
152+
"DecisionTreeRegressor": "Decision tree",
153+
}
154+
155+
def __init__(self, model, X, y):
156+
# Ensure input/output is a DataFrame for consistency
157+
X_df = pd.DataFrame(X)
158+
y_df = pd.DataFrame(y)
159+
160+
is_classifier = hasattr(model, "classes_")
161+
is_binary_classifier = is_classifier and len(model.classes_) == 2
162+
is_clusterer = hasattr(model, "cluster_centers_")
163+
164+
# If not a classfier or a clustering algorithm and output is a single column, then
165+
# assume its a regression algorithm
166+
is_regressor = not is_classifier and not is_clusterer and y_df.shape[1] == 1
167+
168+
if not is_classifier and not is_regressor and not is_clusterer:
169+
raise ValueError(f"Unexpected model type {model} received.")
170+
171+
self._is_classifier = is_classifier
172+
self._is_binary_classifier = is_binary_classifier
173+
self._is_regressor = is_regressor
174+
self._is_clusterer = is_clusterer
175+
self._model = model
176+
177+
if not hasattr(y, "name") and not hasattr(y, "columns"):
178+
# If example output doesn't contain column names then our DataFrame equivalent
179+
# also lacks good column names. Assign reasonable names for use downstream.
180+
if y_df.shape[1] == 1:
181+
y_df.columns = ["I_Target"]
182+
elif self.is_classifier:
183+
# Output is probability of each label. Name columns according to classes.
184+
y_df.columns = [f"P_{class_}" for class_ in model.classes_]
185+
else:
186+
# This *shouldn't* happen unless a cluster algorithm somehow produces wide output.
187+
raise ValueError(f"Unrecognized model output format.")
188+
189+
# Store the data sets for reference later.
190+
self._X = X_df
191+
self._y = y_df
192+
193+
@property
194+
def algorithm(self):
195+
# Get the model or the last step in the Pipeline
196+
estimator = getattr(self.model, "_final_estimator", self.model)
197+
estimator = type(estimator).__name__
198+
199+
# Convert the class name to an algorithm, or return the class name if no match.
200+
return self._algorithm_mappings.get(estimator, estimator)
201+
202+
@property
203+
def is_binary_classifier(self):
204+
return self._is_binary_classifier
205+
206+
@property
207+
def is_classifier(self):
208+
return self._is_classifier
209+
210+
@property
211+
def is_clusterer(self):
212+
return self._is_clusterer
213+
214+
@property
215+
def is_regressor(self):
216+
return self._is_regressor
217+
218+
@property
219+
def model(self):
220+
return self._model
221+
222+
@property
223+
def model_params(self) -> Dict[str, Any]:
224+
return self.model.get_params()
225+
226+
@property
227+
def output_column_names(self):
228+
return list(self._y.columns)
229+
230+
@property
231+
def predict_function(self):
232+
# If desired output has multiple columns then we can assume its the probability values
233+
if self._y.shape[1] > 1 and hasattr(self.model, "predict_proba"):
234+
return self.model.predict_proba
235+
236+
# Otherwise its the single value from .predict()
237+
return self.model.predict
238+
239+
@property
240+
def target_values(self):
241+
if self.is_binary_classifier:
242+
return [self.model.classes_[-1]]
243+
if self.is_classifier:
244+
return list(self.model.classes_)
245+
246+
@property
247+
def threshold(self):
248+
# sklearn seems to always use 0.5 as a cutoff for .predict()
249+
if self.is_binary_classifier:
250+
return 0.5

0 commit comments

Comments
 (0)