Skip to content

Commit 0e33212

Browse files
authored
Bugfix add model content (#101)
* rename * bugfix * add test for replacing model content
1 parent 97c057d commit 0e33212

8 files changed

+669
-1473
lines changed

.github/workflows/build-test-deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: "Test & Build"
1+
name: "Build"
22
on: [push, workflow_dispatch]
33
jobs:
44
codeanalysis:

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
Unreleased
22
----------
3-
-
3+
**Improvements**
4+
- `model_repository.add_model_content()` will now overwrite existing files instead of failing.
5+
6+
**Bugfixes**
7+
- `PagedList.__repr__()` no longer appears to be an empty list.
48

59
v1.6.0 (2021-06-29)
610
-------------------
7-
**Improvements**
11+
**Improvements**
812
- `Session` now supports authorization using OAuth2 tokens. Use the `token=` parameter in the constructor when
913
an existing access token token is known. Alternatively, omitting the `username=` and `password=` parameters
1014
will now prompt the user for an auth code.

src/sasctl/_services/model_repository.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,23 @@ def add_model_content(cls, model, file, name, role=None, content_type=None):
406406

407407
metadata = {'role': role, 'name': name}
408408

409-
return cls.post('/models/{}/contents'.format(id_), files=files, data=metadata)
409+
# return cls.post('/models/{}/contents'.format(id_), files=files, data=metadata)
410+
411+
# if the file already exists, a 409 error will be returned
412+
try:
413+
return cls.post('/models/{}/contents'.format(id_),
414+
files=files, data=metadata)
415+
# delete the older duplicate model and rerun the API call
416+
except HTTPError as e:
417+
if e.code == 409:
418+
model_contents = cls.get_model_contents(id_)
419+
for item in model_contents:
420+
if item.name == name:
421+
cls.delete('/models/{}/contents/{}'.format(id_, item.id))
422+
return cls.post('/models/{}/contents'.format(id_),
423+
files=files, data=metadata)
424+
else:
425+
raise e
410426

411427
@classmethod
412428
def default_repository(cls):

src/sasctl/core.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,8 +1268,6 @@ def __iter__(self):
12681268
# All Iterators are also Iterables
12691269
return self
12701270

1271-
next = __next__ # Python 2 compatible
1272-
12731271
def _request_async(self, start):
12741272
"""Used by worker threads to retrieve next batch of items."""
12751273

@@ -1338,26 +1336,16 @@ def __next__(self):
13381336
if not self._cache:
13391337
self._cache = next(self._pager)
13401338

1341-
if len(self._cache) < self._pager._limit:
1342-
# number of items returned in page was less than expected
1343-
# might be last page, or items might have been filtered out by server.
1344-
pass
1345-
13461339
# Return the next item
13471340
if self._cache:
13481341
self._count -= 1
13491342
return self._cache.pop(0)
13501343

13511344
raise StopIteration()
1352-
# Out of items and out of pages
1353-
# self._count = 0
1354-
# raise StopIteration
13551345

13561346
def __iter__(self):
13571347
return self
13581348

1359-
next = __next__ # Python 2 compatible
1360-
13611349

13621350
class PagedListIterator:
13631351
"""Iterates over an instance of PagedList
@@ -1388,8 +1376,6 @@ def __next__(self):
13881376
def __iter__(self):
13891377
return self
13901378

1391-
next = __next__ # Python 2 compatibility
1392-
13931379

13941380
class PagedList(list):
13951381
"""List that dynamically loads items from the server.
@@ -1418,19 +1404,20 @@ class PagedList(list):
14181404

14191405
def __init__(self, obj, session=None, threads=4):
14201406
super(PagedList, self).__init__()
1421-
self._pager = PagedItemIterator(obj, session=session, threads=threads)
1407+
self._paged_items = PagedItemIterator(obj, session=session, threads=threads)
14221408

1423-
# Add the first page of items to the list
1424-
for _ in range(len(self._pager._cache)):
1425-
self.append(next(self._pager))
1409+
# Go ahead and add the items that were initially returned.
1410+
# Do this by "paging" so iterator remains at the correct spot.
1411+
for _ in range(len(obj['items'])):
1412+
self.append(next(self._paged_items))
14261413

14271414
# Assume that server has more items available
14281415
self._has_more = True
14291416

14301417
def __len__(self):
14311418
if self._has_more:
14321419
# Estimate the total length as items downloaded + items still on server
1433-
return super(PagedList, self).__len__() + len(self._pager)
1420+
return super(PagedList, self).__len__() + len(self._paged_items)
14341421
else:
14351422
# We've pulled everything from the server, so we have an exact length now.
14361423
return super(PagedList, self).__len__()
@@ -1459,7 +1446,7 @@ def __getitem__(self, item):
14591446
# Iterate through server-side pages until we've loaded
14601447
# the item at the requested index.
14611448
while super(PagedList, self).__len__() <= idx:
1462-
n = next(self._pager)
1449+
n = next(self._paged_items)
14631450
self.append(n)
14641451

14651452
except StopIteration:
@@ -1469,13 +1456,13 @@ def __getitem__(self, item):
14691456
# Get the item from the list
14701457
return super(PagedList, self).__getitem__(item)
14711458

1472-
def __str__(self):
1473-
string = super(PagedList, self).__str__()
1459+
def __repr__(self):
1460+
string = super(PagedList, self).__repr__()
14741461

14751462
# If the list has more "items" than are stored in the underlying list
14761463
# then there are more downloads to make.
14771464
if len(self) - super(PagedList, self).__len__() > 0:
1478-
string = string.rstrip(']') + ', ... ]'
1465+
string = string.rstrip(']') + ', ...]'
14791466

14801467
return string
14811468

tests/cassettes/test_project_with_sas_and_sklearn_regression_models.test.json

Lines changed: 457 additions & 1284 deletions
Large diffs are not rendered by default.

tests/cassettes/test_project_with_sas_and_sklearn_regression_models.test_swat.json

Lines changed: 160 additions & 160 deletions
Large diffs are not rendered by default.

tests/scenarios/test_project_with_sas_and_sklearn_regression_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def test(cas_session, boston_dataset):
6464
sas_model = register_model(astore, SAS_MODEL_NAME, PROJECT_NAME, force=True)
6565
sk_model = register_model(sk_model, SCIKIT_MODEL_NAME, PROJECT_NAME, input=X)
6666

67+
# Test overwriting model content
68+
mr.add_model_content(sk_model, 'Your mother was a hamster!', 'insult.txt')
69+
mr.add_model_content(sk_model, 'And your father smelt of elderberries!', 'insult.txt')
70+
6771
# Publish to MAS
6872
sas_module = publish_model(sas_model, 'maslocal', replace=True)
6973
sk_module = publish_model(sk_model, 'maslocal', replace=True)

tests/unit/test_pagedlist.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@ def test_len_no_paging():
1515
items = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}]
1616
obj = RestObj(items=items, count=len(items))
1717

18+
# PagedList should end up effectively identical to a standard list since no paging required.
19+
target = [RestObj(x) for x in items]
20+
1821
with mock.patch('sasctl.core.request') as request:
1922
l = PagedList(obj)
23+
assert str(l) == str(target)
24+
assert repr(l) == repr(target)
2025
assert len(l) == 3
2126

2227
for i, o in enumerate(l):
@@ -82,9 +87,9 @@ def side_effect(_, link, **kwargs):
8287
if i < len(source_items) - 1:
8388
# Ellipses should indicate unfetched results unless we're
8489
# at the end of the list
85-
assert str(l).endswith(', ... ]')
90+
assert str(l).endswith(', ...]')
8691
else:
87-
assert not str(l).endswith(', ... ]')
92+
assert not str(l).endswith(', ...]')
8893

8994

9095
def test_getitem_paging(paging):
@@ -95,9 +100,16 @@ def test_getitem_paging(paging):
95100
# length of list should equal total # of items
96101
assert len(l) == len(items)
97102

103+
# If number of items on first page don't match total number of items then
104+
# some paging is required, so repr() should contain elipses indicating more data.
105+
if len(obj['items']) < obj.count:
106+
assert str(l).endswith(', ...]')
107+
98108
for i, item in enumerate(l):
99109
assert item.name == RestObj(items[i]).name
100110

111+
assert not str(l).endswith(', ...]')
112+
101113

102114
def test_get_item_inflated_len():
103115
"""Test behavior when server overestimates the number of items available."""

0 commit comments

Comments
 (0)