Skip to content

Commit 1fb479b

Browse files
authored
bugfix: correctly handle regression models with nominal inputs (#103)
* bugfix: correctly identify classification vs regression models * code qual improvements
1 parent ad10531 commit 1fb479b

6 files changed

+6325
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
Unreleased
22
----------
3-
-
3+
**Bugfixes**
4+
- Fixed an issue with `register_model()` where random forest, gradient boosting, and SVM regression models with
5+
nominal inputs were incorrectly treated as classification models.
46

57
v1.6.1 (2021-09-01)
68
-------------------

src/sasctl/utils/astore.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,12 +350,19 @@ def _get_model_properties(result):
350350
else:
351351
algorithm = None
352352

353+
def is_classification(r):
354+
"""Determine if the ASTORE model describes a classification model."""
355+
return classification_target(r) is not None
356+
353357
def classification_target(r):
354-
target = r.OutputVariables.Name.str.startswith('I_')
355-
target = r.OutputVariables.Name[target].iloc[0]
356-
return target.replace('I_', '', 1)
358+
"""Get the name of the classification target variable."""
359+
target = r.OutputVariables.Name[r.OutputVariables.Name.str.startswith('I_')]
360+
if target.shape[0] > 0:
361+
return target.iloc[0].replace('I_', '', 1)
362+
return None
357363

358364
def regression_target(r):
365+
"""Get the name of the regression target variable."""
359366
target = r.OutputVariables.Name.str.startswith('P_')
360367
target = r.OutputVariables.Name[target].iloc[0]
361368
return target.replace('P_', '', 1)
@@ -375,7 +382,7 @@ def regression_target(r):
375382
elif algorithm == 'forest':
376383
properties['algorithm'] = 'Random forest'
377384

378-
if 'Classification' in result.InputVariables.Type.values:
385+
if is_classification(result):
379386
properties['function'] = 'classification'
380387
properties['targetVariable'] = classification_target(result)
381388
else:
@@ -385,7 +392,7 @@ def regression_target(r):
385392
elif algorithm == 'gradboost':
386393
properties['algorithm'] = 'Gradient boosting'
387394

388-
if 'Classification' in result.InputVariables.Type.values:
395+
if is_classification(result):
389396
properties['function'] = 'classification'
390397
properties['targetVariable'] = classification_target(result)
391398

@@ -398,7 +405,7 @@ def regression_target(r):
398405
elif algorithm == 'svmachine':
399406
properties['algorithm'] = 'Support vector machine'
400407

401-
if 'Classification' in result.InputVariables.Type.values:
408+
if is_classification(result):
402409
properties['function'] = 'classification'
403410
properties['targetVariable'] = classification_target(result)
404411
properties['targetLevel'] = 'binary'

tests/cassettes/tests.integration.test_astore_models.test_forest_regression_with_nominals_swat.json

Lines changed: 2069 additions & 0 deletions
Large diffs are not rendered by default.

tests/cassettes/tests.integration.test_astore_models.test_gradboost_regression_with_nominals_swat.json

Lines changed: 2075 additions & 0 deletions
Large diffs are not rendered by default.

tests/cassettes/tests.integration.test_astore_models.test_svm_regression_with_nominals_swat.json

Lines changed: 2075 additions & 0 deletions
Large diffs are not rendered by default.

tests/integration/test_astore_models.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,37 @@ def test_forest_regression(cas_session, boston_dataset):
215215
check_input_variables(files, BOSTON_INPUT_VARS)
216216

217217

218+
def test_forest_regression_with_nominals(cas_session, boston_dataset):
219+
target = {
220+
'tool': 'SAS Visual Data Mining and Machine Learning',
221+
'targetVariable': 'Price',
222+
'scoreCodeType': 'ds2MultiType',
223+
'function': 'prediction',
224+
'algorithm': 'Random forest',
225+
}
226+
227+
cas_session.loadactionset('decisiontree')
228+
cas_session.loadactionset('astore')
229+
230+
tbl = cas_session.upload(boston_dataset).casTable
231+
232+
tbl.decisiontree.foresttrain(
233+
target='Price',
234+
inputs=list(boston_dataset.columns[:-1]),
235+
nominals=['chas'],
236+
saveState='astore',
237+
)
238+
239+
desc = cas_session.astore.describe(rstore='astore', epcode=True)
240+
props = _get_model_properties(desc)
241+
242+
for k, v in target.items():
243+
assert props[k] == v
244+
245+
files = create_files_from_astore(cas_session.CASTable('astore'))
246+
check_input_variables(files, BOSTON_INPUT_VARS)
247+
248+
218249
def test_gradboost_binary_classification(cas_session, cancer_dataset):
219250
target = {
220251
'tool': 'SAS Visual Data Mining and Machine Learning',
@@ -300,6 +331,37 @@ def test_gradboost_regression(cas_session, boston_dataset):
300331
check_input_variables(files, BOSTON_INPUT_VARS)
301332

302333

334+
def test_gradboost_regression_with_nominals(cas_session, boston_dataset):
335+
target = {
336+
'tool': 'SAS Visual Data Mining and Machine Learning',
337+
'targetVariable': 'Price',
338+
'scoreCodeType': 'ds2MultiType',
339+
'function': 'prediction',
340+
'algorithm': 'Gradient boosting',
341+
}
342+
343+
cas_session.loadactionset('decisiontree')
344+
cas_session.loadactionset('astore')
345+
346+
tbl = cas_session.upload(boston_dataset).casTable
347+
348+
tbl.decisiontree.gbtreetrain(
349+
target='Price',
350+
inputs=list(boston_dataset.columns[:-1]),
351+
nominals=['chas'],
352+
savestate='astore',
353+
)
354+
355+
desc = cas_session.astore.describe(rstore='astore', epcode=True)
356+
props = _get_model_properties(desc)
357+
358+
for k, v in target.items():
359+
assert props[k] == v
360+
361+
files = create_files_from_astore(cas_session.CASTable('astore'))
362+
check_input_variables(files, BOSTON_INPUT_VARS)
363+
364+
303365
def test_neuralnet_regression(cas_session, boston_dataset):
304366
target = {
305367
'tool': 'SAS Visual Data Mining and Machine Learning',
@@ -393,6 +455,34 @@ def test_svm_regression(cas_session, boston_dataset):
393455
check_input_variables(files, BOSTON_INPUT_VARS)
394456

395457

458+
def test_svm_regression_with_nominals(cas_session, boston_dataset):
459+
target = {
460+
'tool': 'SAS Visual Data Mining and Machine Learning',
461+
'targetVariable': 'Price',
462+
'scoreCodeType': 'ds2MultiType',
463+
'function': 'prediction',
464+
'algorithm': 'Support vector machine',
465+
}
466+
467+
cas_session.loadactionset('svm')
468+
cas_session.loadactionset('astore')
469+
470+
tbl = cas_session.upload(boston_dataset).casTable
471+
472+
tbl.svm.svmTrain(
473+
target='Price', inputs=list(boston_dataset.columns[:-1]), nominals=['chas'], saveState='astore'
474+
)
475+
476+
desc = cas_session.astore.describe(rstore='astore', epcode=True)
477+
props = _get_model_properties(desc)
478+
479+
for k, v in target.items():
480+
assert props[k] == v
481+
482+
files = create_files_from_astore(cas_session.CASTable('astore'))
483+
check_input_variables(files, BOSTON_INPUT_VARS)
484+
485+
396486
def test_bayesnet_binary_classification(cas_session, cancer_dataset):
397487
target = {
398488
'tool': 'SAS Visual Data Mining and Machine Learning',

0 commit comments

Comments
 (0)