|
18 | 18 | from six.moves.urllib.error import HTTPError
|
19 | 19 |
|
20 | 20 | from . import utils
|
21 |
| -from .core import RestObj, current_session, get, get_link, request_link |
| 21 | +from .core import RestObj, current_session, get, get_link, request_link, delete |
22 | 22 | from .exceptions import AuthorizationError
|
23 | 23 | from .services import model_management as mm
|
24 | 24 | from .services import model_publish as mp
|
@@ -136,41 +136,6 @@ def register_model(model, name, project, repository=None, input=None,
|
136 | 136 | # If version not specified, default to creating a new version
|
137 | 137 | version = version or 'new'
|
138 | 138 |
|
139 |
| - # If replacing an existing version, make sure the model version exists |
140 |
| - if str(version).lower() != 'new': |
141 |
| - model_obj = mr.get_model(name) |
142 |
| - if model_obj is None: |
143 |
| - raise ValueError("Unable to update version '%s' of model '%s%. " |
144 |
| - "Model not found." % (version, name)) |
145 |
| - model_versions = request_link(model_obj, 'modelVersions') |
146 |
| - assert isinstance(model_versions, list) |
147 |
| - |
148 |
| - # Use 'new' to create a new version if one doesn't exist yet. |
149 |
| - if len(model_versions) == 0: |
150 |
| - raise ValueError("No existing version of model '%s' to update." |
151 |
| - % name) |
152 |
| - |
153 |
| - # Help function for extracting version number of REST response |
154 |
| - def get_version(x): |
155 |
| - return float(x.get('modelVersionName', 0)) |
156 |
| - |
157 |
| - if str(version).isnumeric(): |
158 |
| - match = [x for x in model_versions if float(version) == |
159 |
| - get_version(x)] |
160 |
| - assert len(match) <= 1 |
161 |
| - |
162 |
| - match = match[0] if len(match) else None |
163 |
| - elif str(version).lower() == 'latest': |
164 |
| - # Sort by version number and get first |
165 |
| - match = sorted(model_versions, key=get_version)[0] |
166 |
| - else: |
167 |
| - raise ValueError("Unrecognized version '%s'." % version) |
168 |
| - |
169 |
| - # TODO: get ID of correct model version |
170 |
| - # if version != new, get existing model |
171 |
| - # get model (modelVersions) rel |
172 |
| - # -> returns list w/ id, modelVersionName, etc |
173 |
| - |
174 | 139 | files = files or []
|
175 | 140 |
|
176 | 141 | # Find the project if it already exists
|
@@ -340,7 +305,18 @@ def get_version(x):
|
340 | 305 | project['eventProbabilityVariable'] = prediction_variable
|
341 | 306 | mr.update_project(project)
|
342 | 307 |
|
343 |
| - model = mr.create_model(model, project) |
| 308 | + # If replacing an existing version, make sure the model version exists |
| 309 | + if str(version).lower() != 'new': |
| 310 | + #Update an existing model with new files |
| 311 | + model_obj = mr.get_model(name) |
| 312 | + if model_obj is None: |
| 313 | + raise ValueError("Unable to update version '%s' of model '%s%. " |
| 314 | + "Model not found." % (version, name)) |
| 315 | + model = mr.create_model_version(name) |
| 316 | + mr.delete_model_contents(model) |
| 317 | + else: |
| 318 | + #Assume new model to create |
| 319 | + model = mr.create_model(model, project) |
344 | 320 |
|
345 | 321 | assert isinstance(model, RestObj)
|
346 | 322 |
|
|
0 commit comments