Skip to content

Commit 83fe479

Browse files
committed
Revert task refactor changes to side branch and run black/isort
1 parent 7610b3d commit 83fe479

File tree

9 files changed

+22
-26
lines changed

9 files changed

+22
-26
lines changed

examples/register_scikit_classification_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# Register the model in Model Manager
2626
register_model(model,
2727
model_name,
28-
input_data=X, # Use X to determine model inputs
28+
input=X, # Use X to determine model inputs
2929
project='Iris', # Register in "Iris" project
3030
force=True) # Create project if it doesn't exist
3131

examples/register_scikit_regression_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
project_name = 'Boston Housing'
2929

3030
# Register the model in SAS Model Manager
31-
register_model(model, model_name, project_name, input_data=X, force=True)
31+
register_model(model, model_name, project_name, input=X, force=True)
3232

3333
# Publish the model to the real-time scoring engine
3434
module = publish_model(model_name, 'maslocal', replace=True)

src/sasctl/_services/workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def run_workflow_definition(cls, name, prompts=None):
126126
headers={"Content-Type": "application/vnd.sas.workflow.variables+json"},
127127
)
128128
if isinstance(prompts, dict):
129-
130129
variables = []
131130

132131
# For each prompt defined in the workflow, check if a value was provided.

src/sasctl/tasks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
import json
1010
import logging
1111
import math
12-
import pickle # skipcq BAN-B301
1312
import os
13+
import pickle # skipcq BAN-B301
1414
import re
1515
import sys
1616
import warnings
17+
1718
import pandas as pd
1819

1920
try:
@@ -29,9 +30,8 @@
2930
from .services import model_management as mm
3031
from .services import model_publish as mp
3132
from .services import model_repository as mr
32-
from .utils.pymas import from_pickle
3333
from .utils.misc import installed_packages
34-
34+
from .utils.pymas import from_pickle
3535

3636
logger = logging.getLogger(__name__)
3737

@@ -301,8 +301,8 @@ def register_model(
301301
if create_project:
302302
out_var = []
303303
in_var = []
304-
import zipfile as zp
305304
import copy
305+
import zipfile as zp
306306

307307
zip_file_copy = copy.deepcopy(zip_file)
308308
tmp_zip = zp.ZipFile(zip_file_copy)
@@ -765,7 +765,6 @@ def update_model_performance(data, model, label, refresh=True):
765765

766766
# Upload the performance data to CAS
767767
with sess.as_swat(server=cas_id) as s:
768-
769768
s.setsessopt(messagelevel="warning")
770769

771770
with swat.options(exception_on_severity=2):
@@ -863,9 +862,10 @@ def get_project_kpis(
863862
A pandas DataFrame representing the MM_STD_KPI table. Note that SAS
864863
missing values are replaced with pandas valid missing values.
865864
"""
866-
from .core import is_uuid
867865
from distutils.version import StrictVersion
868866

867+
from .core import is_uuid
868+
869869
# Check the pandas version for where the json_normalize function exists
870870
if pd.__version__ >= StrictVersion("1.0.3"):
871871
from pandas import json_normalize

tests/integration/test_full_pipelines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_register_model(self, boston_dataset):
149149
model.fit(X, y)
150150

151151
model = register_model(
152-
model, self.MODEL_NAME, self.PROJECT_NAME, input_data=X, force=True
152+
model, self.MODEL_NAME, self.PROJECT_NAME, input=X, force=True
153153
)
154154
assert model.name == self.MODEL_NAME
155155
assert model.projectName == self.PROJECT_NAME
@@ -268,7 +268,7 @@ def test_register_model(self, iris_dataset):
268268
model.fit(X, y)
269269

270270
model = register_model(
271-
model, self.MODEL_NAME, self.PROJECT_NAME, input_data=X, force=True
271+
model, self.MODEL_NAME, self.PROJECT_NAME, input=X, force=True
272272
)
273273
assert model.name == self.MODEL_NAME
274274
assert model.projectName == self.PROJECT_NAME

tests/integration/test_tasks.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_register_sklearn(self, sklearn_logistic_model):
9494
sk_model,
9595
SCIKIT_MODEL_NAME,
9696
project=PROJECT_NAME,
97-
input_data=train_df,
97+
input=train_df,
9898
force=True,
9999
)
100100
assert isinstance(model, RestObj)
@@ -195,11 +195,7 @@ def test_register_model(self, sklearn_linear_model):
195195

196196
# Register model and ensure attributes are set correctly
197197
model = register_model(
198-
sk_model,
199-
self.MODEL_NAME,
200-
project=self.PROJECT_NAME,
201-
input_data=X,
202-
force=True,
198+
sk_model, self.MODEL_NAME, project=self.PROJECT_NAME, input=X, force=True
203199
)
204200

205201
assert isinstance(model, RestObj)

tests/unit/test_pageiterator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def paging(request):
4545
def side_effect(_, link, **kwargs):
4646
assert "limit=%d" % limit in link
4747
start = int(re.search(r"(?<=start=)[\d]+", link).group())
48-
return RestObj(items=items[start: start + limit])
48+
return RestObj(items=items[start : start + limit])
4949

5050
req.side_effect = side_effect
5151
yield obj, items[:], req

tests/unit/test_tasks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88

99
import pytest
1010

11-
from sasctl.core import RestObj
1211
from sasctl._services.model_repository import ModelRepository
12+
from sasctl.core import RestObj
1313

1414

1515
def test_sklearn_metadata():
1616
pytest.importorskip("sklearn")
1717

18-
from sasctl.tasks import _sklearn_to_dict
19-
from sklearn.linear_model import LogisticRegression, LinearRegression
20-
from sklearn.ensemble import GradientBoostingClassifier
21-
from sklearn.tree import DecisionTreeClassifier
22-
from sklearn.ensemble import RandomForestClassifier
18+
from sklearn.ensemble import (GradientBoostingClassifier,
19+
RandomForestClassifier)
20+
from sklearn.linear_model import LinearRegression, LogisticRegression
2321
from sklearn.svm import SVC
22+
from sklearn.tree import DecisionTreeClassifier
23+
24+
from sasctl.tasks import _sklearn_to_dict
2425

2526
info = _sklearn_to_dict(LinearRegression())
2627
assert info["algorithm"] == "Linear regression"
@@ -137,6 +138,7 @@ def test_register_model_403_error(get_project, list_repositories):
137138
See: https://github.com/sassoftware/python-sasctl/issues/39
138139
"""
139140
from urllib.error import HTTPError
141+
140142
from sasctl.exceptions import AuthorizationError
141143
from sasctl.tasks import register_model
142144

tests/unit/test_workflow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import pytest
1111

12-
from sasctl.core import RestObj
1312
from sasctl._services import workflow
13+
from sasctl.core import RestObj
1414

1515

1616
def test_list_workflow_prompt_invalidworkflow():
@@ -154,7 +154,6 @@ def test_run_workflow_definition_with_prompts(get_workflow, post):
154154
# Check each prompt value that was passed and ensure it was correctly
155155
# matched to the prompts defined by the workflow.
156156
for name, value in PROMPTS.items():
157-
158157
# Find the matching variable entry in the POST data
159158
variable = next(v for v in params["json"]["variables"] if v["name"] == name)
160159

0 commit comments

Comments
 (0)