Skip to content

Commit 856afef

Browse files
committed
Update add_model_content to current API specs; include dict uploads
1 parent 4a3da74 commit 856afef

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

src/sasctl/_services/model_repository.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
"""The Model Repository service supports registering and managing models."""
88

99
from warnings import warn
10+
from io import StringIO
11+
import json
1012

1113
from .service import Service
1214
from ..core import current_session, get, delete, sasctl_command, HTTPError
@@ -366,22 +368,24 @@ def create_model(
366368
)
367369

368370
@classmethod
369-
def add_model_content(cls, model, file, name, role=None, content_type=None):
370-
"""Add additional files to the model.
371+
def add_model_content(cls, model, file, name, content_type="multipart/form-data", role=None):
372+
"""Add additional files to the model. Additional files can come in the form of
373+
a bytes-like, string, or dict object. String and dict objects will be converted to
374+
a bytes-like object for upload.
371375
372376
Parameters
373377
----------
374378
model : str or dict
375379
The name or id of the model, or a dictionary representation of
376380
the model.
377-
file : str or bytes
381+
file : str, dict, or bytes
378382
A file related to the model, such as the model code.
379383
name : str
380384
Name of the file related to the model.
381-
role : str
382-
Role of the model file, such as 'Python pickle'.
383-
content_type : str
384-
an HTTP Content-Type value
385+
content_type : str, optional
386+
An HTTP Content-Type value. Default value is multipart/form-data.
387+
role : str, optional
388+
Role of the model file, such as 'Python pickle'. Default value is None.
385389
386390
Returns
387391
-------
@@ -396,38 +400,45 @@ def add_model_content(cls, model, file, name, role=None, content_type=None):
396400
else:
397401
model = cls.get_model(model)
398402
id_ = model["id"]
399-
400-
if content_type is None and isinstance(file, bytes):
401-
content_type = "application/octet-stream"
402-
403-
if content_type is not None:
404-
files = {name: (name, file, content_type)}
403+
404+
# Convert string file representations to bytes-like object
405+
if isinstance(file, str):
406+
file = StringIO(file)
407+
content_type="multipart/form-data"
408+
# Convert dict file representations to bytes-like object
409+
elif isinstance(file, dict):
410+
file = StringIO(json.dumps(file))
411+
content_type="multipart/form-data"
412+
413+
files = {"files": (name, file, content_type)}
414+
415+
if role is None:
416+
params = {}
405417
else:
406-
files = {name: file}
407-
408-
metadata = {"role": role, "name": name}
409-
410-
# return cls.post('/models/{}/contents'.format(id_), files=files, data=metadata)
418+
params = {"role": role}
419+
params = "&".join("{}={}".format(k, v) for k, v in params.items())
411420

412-
# if the file already exists, a 409 error will be returned
421+
# If the file already exists, a 409 error will be returned
413422
try:
414423
return cls.post(
415-
"/models/{}/contents".format(id_), files=files, data=metadata
424+
"/models/{}/contents".format(id_),
425+
files=files,
426+
params=params,
416427
)
417-
# delete the older duplicate model and rerun the API call
418-
except HTTPError as e:
419-
if e.code == 409:
428+
except HTTPError as err:
429+
# Delete the older duplicate model file and rerun the API call
430+
if err.code == 409:
420431
model_contents = cls.get_model_contents(id_)
421432
for item in model_contents:
422433
if item.name == name:
423434
cls.delete("/models/{}/contents/{}".format(id_, item.id))
424435
return cls.post(
425436
"/models/{}/contents".format(id_),
426437
files=files,
427-
data=metadata,
438+
params=params,
428439
)
429440
else:
430-
raise e
441+
raise err
431442

432443
@classmethod
433444
def default_repository(cls):

0 commit comments

Comments
 (0)