Skip to content

Commit ede6691

Browse files
committed
Better handling for nominal, ordinal, and binary classification models.
1 parent 6399c12 commit ede6691

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

src/sasctl/pzmm/writeScoreCode.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def writeScoreCode(
6868
that the model files are being created from an MLFlow model.
6969
targetDF : DataFrame
7070
The `DataFrame` object contains the training data for the target variable. Note that
71-
for MLFlow models, this can set as None.
71+
for MLFlow models, this can be set as None.
7272
modelPrefix : string
7373
The variable for the model name that is used when naming model files.
7474
(For example: hmeqClassTree + [Score.py || .pickle]).
@@ -80,8 +80,16 @@ def writeScoreCode(
8080
modelFileName : string
8181
Name of the model file that contains the model.
8282
metrics : string list, optional
83-
The scoring metrics for the model. The default is a set of two
84-
metrics: EM_EVENTPROBABILITY and EM_CLASSIFICATION.
83+
The scoring metrics for the model. For classification models, it is assumed that the last value in the list
84+
represents the classification output. The default is a list of two metrics: EM_EVENTPROBABILITY and
85+
EM_CLASSIFICATION. The following scenarios are supported:
86+
1) If only one value is provided, then it is assumed that the model returns either a binary response
87+
prediction or a character output, which is returned as the output.
88+
2) If only two values are provided, a threshold value needs to be set: either by providing a
89+
threshPrediction argument or by letting the function take the mean of the provided target column. Then the
90+
threshold value sets the classification output for the prediction.
91+
3) If more than two values are provided, the largest probability is accepted as the event and the
92+
appropriate classification value is returned for the output.
8593
pyPath : string, optional
8694
The local path of the score code file. The default is the current
8795
working directory.
@@ -462,19 +470,23 @@ def score{modelPrefix}({inputVarList}):
462470
)
463471
)
464472
if not isH2OModel and not isMLFlow:
465-
cls.pyFile.write(
466-
"""\n
473+
# TODO: Refactor arguments to better handle different classification types
474+
if len(metrics) == 1:
475+
# For models that output the classification from the prediction
476+
cls.pyFile.write(
477+
"""\n
478+
{metric} = prediction""".format(metric=metrics[0]))
479+
elif len(metrics) == 2:
480+
cls.pyFile.write(
481+
"""\n
467482
try:
468483
{metric} = float(prediction)
469484
except TypeError:
470-
# If the model expects non-binary responses, a TypeError will be raised.
471-
# The except block shifts the prediction to accept a non-binary response.
472-
{metric} = float(prediction[:,1])""".format(
473-
metric=metrics[0]
474-
)
475-
)
476-
if threshPrediction is None:
477-
threshPrediction = np.mean(targetDF)
485+
# If the prediction returns as a list of values or improper value type, a TypeError will be raised.
486+
# Attempt to handle the prediction output in the except block.
487+
{metric} = float(prediction[0])""".format(metric=metrics[0]))
488+
if threshPrediction is None:
489+
threshPrediction = np.mean(targetDF)
478490
cls.pyFile.write(
479491
"""\n
480492
if ({metric0} >= {threshold}):
@@ -486,6 +498,21 @@ def score{modelPrefix}({inputVarList}):
486498
threshold=threshPrediction,
487499
)
488500
)
501+
elif len(metrics) > 2:
502+
for i, metric in enumerate(metrics[:-1]):
503+
cls.pyFile.write(
504+
"""\
505+
{metric} = float(prediction[{i}]""".format(metric=metric, i=i)
506+
)
507+
cls.pyFile.write(
508+
"""\
509+
max_prediction = max({metric_list})
510+
index_prediction = {metric_list}.index(max_prediction)
511+
{classification} = index_prediction""".format(metric_list=metrics[:-1], classification=metrics[-1])
512+
)
513+
else:
514+
ValueError("Improper metrics argument was provided. Please provide a list of string metrics.")
515+
489516
elif isH2OModel and not isMLFlow:
490517
cls.pyFile.write(
491518
"""\n

0 commit comments

Comments
 (0)