Skip to content

Commit d55c12e

Browse files
committed
draft updates for model version
1 parent 68d2b2f commit d55c12e

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

src/sasctl/_services/model_repository.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ def get_model_contents(cls, model):
126126
@classmethod
127127
def create_model(cls, model, project, description=None, modeler=None,
128128
function=None, algorithm=None, tool=None,
129-
is_champion=False, properties={}, **kwargs):
129+
is_champion=False, properties={},
130+
version=None,
131+
**kwargs):
130132
"""Creates a model into a project or folder.
131133
132134
Parameters
@@ -177,20 +179,31 @@ def create_model(cls, model, project, description=None, modeler=None,
177179
The display name for the model version.
178180
properties : array_like, optional (custom properties)
179181
Custom model properties that can be set: name, value, type
180-
181182
inputVariables : array_like, optional
182183
Model input variables. By default, these are the same as the model
183184
project.
184185
outputVariables : array_like, optional
185186
Model output variables. By default, these are the same as the model
186187
project.
188+
version : str or int
189+
Whether to create a new version of the model or update an
190+
existing version. May be a specific numbered version to
191+
replace, 'latest' to update the most recent version, or 'new' to
192+
add a new version. Defaults to 'new'
187193
188194
Returns
189195
-------
190196
str
191197
The model schema returned in JSON format.
192198
193199
"""
200+
version = version or 'new'
201+
202+
# Check if the model already exists
203+
model_obj = cls.get_model(model)
204+
205+
is_new_model = model_obj == None
206+
194207
if isinstance(model, str):
195208
model = {'name': model}
196209

@@ -209,15 +222,16 @@ def create_model(cls, model, project, description=None, modeler=None,
209222
model['tool'] = tool or model.get('tool')
210223
model['champion'] = is_champion or model.get('champion')
211224
model['role'] = 'Champion' if model.get('champion',
212-
False) else 'Challenger'
225+
False) else None
213226
model['description'] = description or model.get('description')
214227
model.setdefault('properties', [{'name': k, 'value': v} for k, v in
215228
properties.items()])
216229

217230
# TODO: add kwargs (pop)
218231
# model.update(kwargs)
219-
return cls.post('/models', json=model, headers={
220-
'Content-Type': 'application/vnd.sas.models.model+json'})
232+
if is_new_model:
233+
return cls.post('/models', json=model, headers={
234+
'Content-Type': 'application/vnd.sas.models.model+json'})
221235

222236
@classmethod
223237
def add_model_content(cls, model, file, name=None, role=None):

src/sasctl/tasks.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def register_model(model, name, project, repository=None, input=None,
7575
input
7676
version : {'new', 'latest', int}, optional
7777
Version number of the project in which the model should be created.
78-
files :
78+
Defaults to 'new'.
79+
files : list
7980
force : bool, optional
8081
Create dependencies such as projects and repositories if they do not
8182
already exist.
@@ -100,6 +101,46 @@ def register_model(model, name, project, repository=None, input=None,
100101
# TODO: Allow file info to be specified
101102
# TODO: Performance stats
102103

104+
# If version not specified, default to creating a new version
105+
version = version or 'new'
106+
107+
# If replacing an existing version, make sure the model version exists
108+
if str(version).lower() != 'new':
109+
model_obj = mr.get_model(name)
110+
if model_obj is None:
111+
raise ValueError("Unable to update version '%s' of model '%s%. "
112+
"Model not found." % (version, name))
113+
model_versions = request_link(model_obj, 'modelVersions')
114+
assert isinstance(model_versions, list)
115+
116+
# Use 'new' to create a new version if one doesn't exist yet.
117+
if len(model_versions) == 0:
118+
raise ValueError("No existing version of model '%s' to update."
119+
% name)
120+
121+
# Help function for extracting version number of REST response
122+
def get_version(x):
123+
return float(x.get('modelVersionName', 0))
124+
125+
if str(version).isnumeric():
126+
match = [x for x in model_versions if float(version) ==
127+
get_version(x)]
128+
assert len(match) <= 1
129+
130+
match = match[0] if len(match) else None
131+
elif str(version).lower() == 'latest':
132+
# Sort by version number and get first
133+
match = sorted(model_versions, key=get_version)[0]
134+
else:
135+
raise ValueError("Unrecognized version '%s'." % version)
136+
137+
138+
139+
# TODO: get ID of correct model version
140+
# if version != new, get existing model
141+
# get model (modelVersions) rel
142+
# -> returns list w/ id, modelVersionName, etc
143+
103144
files = files or []
104145

105146
# Find the project if it already exists
@@ -111,6 +152,7 @@ def register_model(model, name, project, repository=None, input=None,
111152
if p is None and not create_project:
112153
raise ValueError("Project '{}' not found".format(project))
113154

155+
# Use default repository if not specified
114156
if repository is None:
115157
repository = mr.default_repository()
116158
else:

0 commit comments

Comments
 (0)