Skip to content

Commit 578db1b

Browse files
committed
First pass of tasks refactor
1 parent 1edd868 commit 578db1b

File tree

7 files changed

+65
-59
lines changed

7 files changed

+65
-59
lines changed

examples/full_lifecycle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
# Register the model in SAS Model Manager
3636
register_model(lm,
3737
model_name,
38-
input=X_train, # Use X to determine model inputs
39-
project=project, # Register in "Iris" project
38+
input_data=X_train, # Use X to determine model inputs
39+
project=project, # Register in "Iris" project
4040
force=True) # Create project if it doesn't exist
4141

4242
# Update project properties. Target variable must be set before performance
@@ -60,7 +60,7 @@
6060
dt.fit(X_train, y_train)
6161

6262
# Register the second model in Model Manager
63-
model_dt = register_model(dt, 'Decision Tree', project, input=X)
63+
model_dt = register_model(dt, 'Decision Tree', project, input_data=X)
6464

6565
# Publish from Model Manager -> MAS
6666
module_dt = publish_model(model_dt, 'maslocal')

src/sasctl/tasks.py

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,10 @@
4545
_PROP_NAME_MAXLEN = 60
4646

4747

48-
49-
5048
def _property(k, v):
5149
return {"name": str(k)[:_PROP_NAME_MAXLEN], "value": str(v)[:_PROP_VALUE_MAXLEN]}
5250

5351

54-
5552
def _sklearn_to_dict(model):
5653
# Convert Scikit-learn values to built-in Model Manager values
5754
mappings = {
@@ -106,7 +103,9 @@ def _sklearn_to_dict(model):
106103
return result
107104

108105

109-
def _register_sklearn_40(model, model_name, project_name, input_data, output_data, overwrite=False):
106+
def _register_sklearn_40(
107+
model, model_name, project_name, input_data, output_data, overwrite=False
108+
):
110109
model_info = get_model_info(model, input_data, output_data)
111110

112111
# TODO: allow passing description in register_model()
@@ -121,22 +120,16 @@ def _register_sklearn_40(model, model_name, project_name, input_data, output_dat
121120
files.update(pzmm.JSONFiles.write_var_json(input_data))
122121
files.update(pzmm.JSONFiles.write_var_json(output_data, is_input=False))
123122

124-
if model_info.is_binary_classifier:
125-
num_categories = 2
126-
elif model_info.is_classifier:
127-
num_categories = len(model_info.target_values)
128-
else:
129-
num_categories = 0
130-
131-
files.update(pzmm.JSONFiles.write_model_properties_json(model_name,
132-
target_variable=model_info.output_column_names,
133-
target_event=model_info.target_values,
134-
num_target_categories=num_categories,
135-
event_prob_var=None,
136-
model_desc=model_info.description[:_DESC_MAXLEN],
137-
model_function=model_info.analytic_function,
138-
model_type=model_info.algorithm
139-
))
123+
files.update(
124+
pzmm.JSONFiles.write_model_properties_json(
125+
model_name,
126+
target_variable=model_info.output_column_names,
127+
target_values=model_info.target_values,
128+
model_desc=model_info.description[:_DESC_MAXLEN],
129+
model_function=model_info.analytic_function,
130+
model_algorithm=model_info.algorithm,
131+
)
132+
)
140133
"""
141134
target_variable : string
142135
Target variable to be predicted by the model.
@@ -151,15 +144,16 @@ def _register_sklearn_40(model, model_name, project_name, input_data, output_dat
151144
files.update(pzmm.JSONFiles.write_file_metadata_json(model_name))
152145

153146
# TODO: How to determine whether to call .predict() or .predict_proba()? Base it on the output data?
154-
pzmm.ImportModel.import_model(model_files=files,
155-
model_prefix=model_name,
156-
project=project_name,
157-
predict_method=model.predict,
158-
input_data=input_data,
159-
output_variables=[],
160-
score_cas=True,
161-
missing_values=False # assuming Pipeline will be used for imputing.
162-
)
147+
pzmm.ImportModel.import_model(
148+
model_files=files,
149+
model_prefix=model_name,
150+
project=project_name,
151+
predict_method=model.predict,
152+
input_data=input_data,
153+
output_variables=[],
154+
score_cas=True,
155+
missing_values=False, # assuming Pipeline will be used for imputing.
156+
)
163157

164158

165159
def _create_project(project_name, model, repo, input_vars=None, output_vars=None):
@@ -235,7 +229,7 @@ def register_model(
235229
name,
236230
project,
237231
repository=None,
238-
input=None,
232+
input_data=None,
239233
version=None,
240234
files=None,
241235
force=False,
@@ -246,11 +240,11 @@ def register_model(
246240
Parameters
247241
----------
248242
model : swat.CASTable or sklearn.BaseEstimator
249-
The model to register. If an instance of ``swat.CASTable`` the table
250-
is assumed to hold an ASTORE, which will be downloaded and used to
251-
construct the model to register. If a scikit-learn estimator, the
252-
model will be pickled and uploaded to the registry and score code will
253-
be generated for publishing the model to MAS.
243+
The model to register. If an instance of ``swat.CASTable`` the table is assumed
244+
to hold an ASTORE, which will be downloaded and used to construct the model to
245+
register. If a scikit-learn estimator, the model will be pickled and uploaded
246+
to the registry and score code will be generated for publishing the model to
247+
CAS or MAS.
254248
name : str
255249
Designated name for the model in the repository.
256250
project : str or dict
@@ -259,14 +253,14 @@ def register_model(
259253
repository : str or dict, optional
260254
The name or id of the repository, or a dictionary representation of
261255
the repository. If omitted, the default repository will be used.
262-
input : DataFrame, type, list of type, or dict of str: type, optional
256+
input_data : DataFrame, type, list of type, or dict of str: type, optional
263257
The expected type for each input value of the target function.
264258
Can be omitted if target function includes type hints. If a DataFrame
265259
is provided, the columns will be inspected to determine type
266260
information. If a single type is provided, all columns will be assumed
267261
to be that type, otherwise a list of column types or a dictionary of
268262
column_name: type may be provided.
269-
output : array-like
263+
output_data : array-like
270264
A NumPy array or pandas DataFrame that contains sample output values from the model.
271265
version : {'new', 'latest', int}, optional
272266
Version number of the project in which the model should be created.
@@ -305,7 +299,8 @@ def register_model(
305299
Added `record_packages` parameter.
306300
307301
.. versionchanged:: v1.7.4
308-
Update ASTORE handling for ease of use and removal of SAS Viya 4 score code errors
302+
Update ASTORE handling for ease of use and removal of SAS Viya 4 score code
303+
errors
309304
310305
"""
311306
# If version not specified, default to creating a new version
@@ -320,7 +315,7 @@ def register_model(
320315
create_project = bool(p is None and force is True)
321316

322317
if p is None and not create_project:
323-
raise ValueError("Project '{}' not found".format(project))
318+
raise ValueError(f"Project '{project}' not found")
324319

325320
# Use default repository if not specified
326321
try:
@@ -331,7 +326,7 @@ def register_model(
331326
except HTTPError as e:
332327
if e.code == 403:
333328
raise AuthorizationError(
334-
"Unable to register model. User account does not have read permissions "
329+
"Unable to register model. User account does not have read permissions "
335330
"for the /modelRepository/repositories/ URL. Please contact your SAS "
336331
"Viya administrator."
337332
)
@@ -342,9 +337,9 @@ def register_model(
342337
raise ValueError("Unable to find a default repository")
343338

344339
if repo_obj is None:
345-
raise ValueError("Unable to find repository '{}'".format(repository))
340+
raise ValueError(f"Unable to find repository '{repository}'")
346341

347-
# If model is a CASTable then assume it holds an ASTORE model. Import these via a ZIP file.
342+
# If model is a CASTable then assume it holds an ASTORE model; import with zip file
348343
if "swat.cas.table.CASTable" in str(type(model)):
349344
if swat is None:
350345
raise RuntimeError(
@@ -357,7 +352,7 @@ def register_model(
357352
)
358353

359354
if "DataStepSrc" in model.columns:
360-
zip_file = utils.create_package_from_datastep(model, input=input)
355+
zip_file = utils.create_package_from_datastep(model, input=input_data)
361356
if create_project:
362357
out_var = []
363358
in_var = []
@@ -427,7 +422,7 @@ def register_model(
427422

428423
if current_session().version_info() < 4:
429424
# Upload the model as a ZIP file if using Viya 3.
430-
zipfile = utils.create_package(model, input=input)
425+
zipfile = utils.create_package(model, input=input_data)
431426
model = mr.import_model_from_zip(
432427
name, project, zipfile, version=version
433428
)
@@ -456,17 +451,17 @@ def register_model(
456451
# If the model is a scikit-learn model, generate the model dictionary
457452
# from it and pickle the model for storage
458453
if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
459-
460454
# Pickle the model so we can store it
461455
model_pkl = pickle.dumps(model)
462-
files.append({"name": "model.pkl", "file": model_pkl, "role": "Python Pickle"})
456+
files.append({"name": "model.pkl", "file": model_pkl, "role": "Python pickle"})
463457

464458
target_funcs = [f for f in ("predict", "predict_proba") if hasattr(model, f)]
465459

466460
# Extract model properties
467461
model = _sklearn_to_dict(model)
468462
model["name"] = name
469463

464+
# TODO: Swap for pzmm.JSONFiles.create_requirements_json()
470465
# Get package versions in environment
471466
packages = installed_packages()
472467
if record_packages and packages is not None:
@@ -485,10 +480,11 @@ def register_model(
485480
# Generate and upload a requirements.txt file
486481
files.append({"name": "requirements.txt", "file": "\n".join(packages)})
487482

483+
# TODO: Swap for pzmm.ScoreCode.write_score_code()
488484
# Generate PyMAS wrapper
489485
try:
490486
mas_module = from_pickle(
491-
model_pkl, target_funcs, input_types=input, array_input=True
487+
model_pkl, target_funcs, input_types=input_data, array_input=True
492488
)
493489

494490
# Include score code files from ESP and MAS

tests/integration/test_full_pipelines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_register_model(self, boston_dataset):
148148
model.fit(X, y)
149149

150150
model = register_model(
151-
model, self.MODEL_NAME, self.PROJECT_NAME, input=X, force=True
151+
model, self.MODEL_NAME, self.PROJECT_NAME, input_data=X, force=True
152152
)
153153
assert model.name == self.MODEL_NAME
154154
assert model.projectName == self.PROJECT_NAME
@@ -266,7 +266,7 @@ def test_register_model(self, iris_dataset):
266266
model.fit(X, y)
267267

268268
model = register_model(
269-
model, self.MODEL_NAME, self.PROJECT_NAME, input=X, force=True
269+
model, self.MODEL_NAME, self.PROJECT_NAME, input_data=X, force=True
270270
)
271271
assert model.name == self.MODEL_NAME
272272
assert model.projectName == self.PROJECT_NAME

tests/integration/test_tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_register_sklearn(self, sklearn_logistic_model):
9595
sk_model,
9696
SCIKIT_MODEL_NAME,
9797
project=PROJECT_NAME,
98-
input=train_df,
98+
input_data=train_df,
9999
force=True,
100100
)
101101
assert isinstance(model, RestObj)
@@ -196,7 +196,7 @@ def test_register_model(self, sklearn_linear_model):
196196

197197
# Register model and ensure attributes are set correctly
198198
model = register_model(
199-
sk_model, self.MODEL_NAME, project=self.PROJECT_NAME, input=X, force=True
199+
sk_model, self.MODEL_NAME, project=self.PROJECT_NAME, input_data=X, force=True
200200
)
201201

202202
assert isinstance(model, RestObj)

tests/scenarios/test_project_with_sas_and_sklearn_classification_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test(cas_session, iris_dataset):
6767
sk_model.fit(X, y)
6868

6969
sas_model = register_model(astore, SAS_MODEL_NAME, PROJECT_NAME, force=True)
70-
sk_model = register_model(sk_model, SCIKIT_MODEL_NAME, PROJECT_NAME, input=X)
70+
sk_model = register_model(sk_model, SCIKIT_MODEL_NAME, PROJECT_NAME, input_data=X)
7171

7272
# Publish to MAS
7373
sas_module = publish_model(sas_model, "maslocal", replace=True)

tests/scenarios/test_project_with_sas_and_sklearn_regression_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test(cas_session, boston_dataset):
6666
sk_model.fit(X, y)
6767

6868
sas_model = register_model(astore, SAS_MODEL_NAME, PROJECT_NAME, force=True)
69-
sk_model = register_model(sk_model, SCIKIT_MODEL_NAME, PROJECT_NAME, input=X)
69+
sk_model = register_model(sk_model, SCIKIT_MODEL_NAME, PROJECT_NAME, input_data=X)
7070

7171
# Test overwriting model content
7272
mr.add_model_content(sk_model, "Your mother was a hamster!", "insult.txt")

tests/unit/test_tasks.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,29 @@ def test_register_sklearn_with_pzmm(iris_dataset):
8787
# Verify that expected files were generated.
8888
files = kwargs["model_files"]
8989
assert isinstance(files, dict)
90-
assert all(f in files for f in (f"{MODEL_NAME}.pickle", "inputVar.json", "outputVar.json", "ModelProperties.json", "fileMetadata.json"))
90+
assert all(
91+
f in files
92+
for f in (
93+
f"{MODEL_NAME}.pickle",
94+
"inputVar.json",
95+
"outputVar.json",
96+
"ModelProperties.json",
97+
"fileMetadata.json",
98+
)
99+
)
91100

92101
assert kwargs["model_prefix"] == MODEL_NAME
93102
assert kwargs["project"] == PROJECT_NAME
94103
assert kwargs["predict_method"] == model.predict
95104
assert kwargs["output_variables"]
96-
assert kwargs["score_cas"] == True
97-
assert kwargs["missing_values"] == False
105+
assert kwargs["score_cas"] is True
106+
assert kwargs["missing_values"] is False
98107

99108
pd.testing.assert_frame_equal(kwargs["input_data"], X)
100109

101110
pytest.fail("Verify import_model inputs are correct")
102111

112+
103113
"""
104114
metrics : string list
105115
The scoring metrics for the model. For classification models, it is assumed

0 commit comments

Comments
 (0)