Skip to content

Commit 90d913c

Browse files
committed
➕ Unit testing for managers inheritance
1 parent 84d6f54 commit 90d913c

File tree

5 files changed

+202
-20
lines changed

5 files changed

+202
-20
lines changed

saffier/db/queryset.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ async def get_or_none(self, **kwargs):
440440
return self.model_class._from_row(rows[0], select_related=self._select_related)
441441

442442
async def all(self, **kwargs):
443+
"""
444+
Returns the queryset records based on specific filters
445+
"""
443446
queryset = self._clone()
444447
if kwargs:
445448
return await queryset.filter(**kwargs).all()
@@ -452,6 +455,9 @@ async def all(self, **kwargs):
452455
]
453456

454457
async def get(self, **kwargs):
458+
"""
459+
Returns a single record based on the given kwargs.
460+
"""
455461
if kwargs:
456462
return await self.filter(**kwargs).get()
457463

@@ -465,6 +471,9 @@ async def get(self, **kwargs):
465471
return self.model_class._from_row(rows[0], select_related=self._select_related)
466472

467473
async def first(self, **kwargs):
474+
"""
475+
Returns the first record of a given queryset.
476+
"""
468477
queryset = self._clone()
469478
if kwargs:
470479
return await queryset.filter(**kwargs).order_by("id").get()
@@ -474,6 +483,9 @@ async def first(self, **kwargs):
474483
return rows[0]
475484

476485
async def last(self, **kwargs):
486+
"""
487+
Returns the last record of a given queryset.
488+
"""
477489
queryset = self._clone()
478490
if kwargs:
479491
return await queryset.filter(**kwargs).order_by("-id").get()
@@ -483,6 +495,9 @@ async def last(self, **kwargs):
483495
return rows[0]
484496

485497
async def create(self, **kwargs):
498+
"""
499+
Creates a record in a specific table.
500+
"""
486501
kwargs = self._validate_kwargs(**kwargs)
487502
instance = self.model_class(**kwargs)
488503
expression = self.table.insert().values(**kwargs)
@@ -495,6 +510,9 @@ async def create(self, **kwargs):
495510
return instance
496511

497512
async def bulk_create(self, objs: typing.List[typing.Dict]) -> None:
513+
"""
514+
Bulk creates records in a table
515+
"""
498516
new_objs = [self._validate_kwargs(**obj) for obj in objs]
499517

500518
expression = self.table.insert().values(new_objs)
@@ -508,6 +526,9 @@ async def delete(self) -> None:
508526
await self.database.execute(expression)
509527

510528
async def update(self, **kwargs) -> None:
529+
"""
530+
Updates a record in a specific table with the given kwargs.
531+
"""
511532
fields = {
512533
key: field.validator for key, field in self.model_class.fields.items() if key in kwargs
513534
}
@@ -524,6 +545,9 @@ async def update(self, **kwargs) -> None:
524545
async def get_or_create(
525546
self, defaults: typing.Dict[str, typing.Any], **kwargs
526547
) -> typing.Tuple[typing.Any, bool]:
548+
"""
549+
Creates a record in a specific table or updates if already exists.
550+
"""
527551
try:
528552
instance = await self.get(**kwargs)
529553
return instance, False
@@ -535,6 +559,9 @@ async def get_or_create(
535559
async def update_or_create(
536560
self, defaults: typing.Dict[str, typing.Any], **kwargs
537561
) -> typing.Tuple[typing.Any, bool]:
562+
"""
563+
Updates a record in a specific table or creates a new one.
564+
"""
538565
try:
539566
instance = await self.get(**kwargs)
540567
await instance.update(**defaults)

saffier/metaclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _check_manager_for_bases(
8989
if isinstance(value, Manager) and key not in attrs:
9090
attrs[key] = value.__class__()
9191
else:
92-
if not getattr(meta, "abstract", False):
92+
if not meta.abstract:
9393
for key, value in inspect.getmembers(base):
9494
if isinstance(value, Manager) and key not in attrs:
9595
attrs[key] = value.__class__()
@@ -123,7 +123,7 @@ def __search_for_fields(base: typing.Type, attrs: DictAny) -> None:
123123
meta: MetaInfo = getattr(base, "_meta", None)
124124
if not meta:
125125
# Mixins and other classes
126-
for key, value in base.__dict__.items():
126+
for key, value in inspect.getmembers(base):
127127
if isinstance(value, Field) and key not in attrs:
128128
attrs[key] = value
129129

tests/test_managers.py renamed to tests/managers/test_managers.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class Meta:
4040

4141

4242
@pytest.fixture(autouse=True, scope="function")
43-
async def create_xtest_database():
43+
async def create_test_database():
4444
await models.create_all()
4545
yield
4646
await models.drop_all()
@@ -92,20 +92,3 @@ async def test_model_crud_different_manager():
9292
assert len(products) == 1
9393

9494
assert products[0].pk == product.pk
95-
96-
97-
async def test_model_crud_different_manager_create():
98-
products = await Product.active.all()
99-
assert products == []
100-
101-
await Product.active.create(name="One", in_stock=True, is_active=False, rating=5)
102-
await Product.active.create(name="Two", in_stock=True, is_active=False, rating=2)
103-
product = await Product.query.create(name="Three", in_stock=True, is_active=True, rating=3)
104-
105-
products = await Product.query.all()
106-
assert len(products) == 3
107-
108-
products = await Product.active.all()
109-
assert len(products) == 1
110-
111-
assert products[0].pk == product.pk
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import pytest
2+
from tests.settings import DATABASE_URL
3+
4+
import saffier
5+
from saffier import Manager
6+
from saffier.db.connection import Database
7+
from saffier.db.queryset import QuerySet
8+
9+
database = Database(url=DATABASE_URL)
10+
models = saffier.Registry(database=database)
11+
12+
pytestmark = pytest.mark.anyio
13+
14+
15+
class ObjectsManager(Manager):
16+
def get_queryset(self) -> QuerySet:
17+
queryset = super().get_queryset().filter(is_active=True)
18+
return queryset
19+
20+
21+
class LanguageManager(Manager):
22+
def get_queryset(self) -> QuerySet:
23+
queryset = super().get_queryset().filter(language="EN")
24+
return queryset
25+
26+
27+
class BaseModel(saffier.Model):
28+
query = ObjectsManager()
29+
30+
class Meta:
31+
abstract = True
32+
registry = models
33+
34+
35+
class HubUser(BaseModel):
36+
name = saffier.CharField(max_length=100)
37+
language = saffier.CharField(max_length=200, null=True)
38+
39+
languages = LanguageManager()
40+
41+
class Meta:
42+
registry = models
43+
44+
45+
class HubProduct(BaseModel):
46+
name = saffier.CharField(max_length=100)
47+
rating = saffier.IntegerField(minimum=1, maximum=5)
48+
in_stock = saffier.BooleanField(default=False)
49+
is_active = saffier.BooleanField(default=False)
50+
51+
52+
@pytest.fixture(autouse=True, scope="function")
53+
async def create_test_database():
54+
await models.create_all()
55+
yield
56+
await models.drop_all()
57+
58+
59+
@pytest.fixture(autouse=True)
60+
async def rollback_connections():
61+
with database.force_rollback():
62+
async with database:
63+
yield
64+
65+
66+
async def test_inherited_abstract_base_model_managers():
67+
await HubUser.query.create(name="test", language="EN")
68+
await HubUser.query.create(name="test2", language="EN")
69+
await HubUser.query.create(name="test3", language="PT")
70+
await HubUser.query.create(name="test4", language="PT")
71+
72+
# users = await HubUser.query.all()
73+
# assert len(users) == 4
74+
75+
users = await HubUser.languages.all()
76+
assert len(users) == 2
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import pytest
2+
from tests.settings import DATABASE_URL
3+
4+
import saffier
5+
from saffier import Manager
6+
from saffier.db.connection import Database
7+
from saffier.db.queryset import QuerySet
8+
9+
database = Database(url=DATABASE_URL)
10+
models = saffier.Registry(database=database)
11+
12+
pytestmark = pytest.mark.anyio
13+
14+
15+
class ObjectsManager(Manager):
16+
def get_queryset(self) -> QuerySet:
17+
queryset = super().get_queryset().filter(is_active=True)
18+
return queryset
19+
20+
21+
class LanguageManager(Manager):
22+
def get_queryset(self) -> QuerySet:
23+
queryset = super().get_queryset().filter(language="EN")
24+
return queryset
25+
26+
27+
class RatingManager(Manager):
28+
def get_queryset(self) -> QuerySet:
29+
queryset = super().get_queryset().filter(rating__gte=3)
30+
return queryset
31+
32+
33+
class BaseModel(saffier.Model):
34+
query = ObjectsManager()
35+
languages = LanguageManager()
36+
ratings = RatingManager()
37+
38+
class Meta:
39+
registry = models
40+
41+
42+
class User(BaseModel):
43+
name = saffier.CharField(max_length=100)
44+
language = saffier.CharField(max_length=200, null=True)
45+
46+
class Meta:
47+
registry = models
48+
49+
50+
class Product(BaseModel):
51+
name = saffier.CharField(max_length=100)
52+
rating = saffier.IntegerField(minimum=1, maximum=5)
53+
in_stock = saffier.BooleanField(default=False)
54+
is_active = saffier.BooleanField(default=False)
55+
56+
57+
@pytest.fixture(autouse=True, scope="function")
58+
async def create_test_database():
59+
await models.create_all()
60+
yield
61+
await models.drop_all()
62+
63+
64+
@pytest.fixture(autouse=True)
65+
async def rollback_connections():
66+
with database.force_rollback():
67+
async with database:
68+
yield
69+
70+
71+
async def test_inherited_base_model_managers():
72+
await User.query.create(name="test", language="EN")
73+
await User.query.create(name="test2", language="EN")
74+
await User.query.create(name="test3", language="PT")
75+
await User.query.create(name="test4", language="PT")
76+
77+
users = await User.query.all()
78+
assert len(users) == 4
79+
80+
users = await User.languages.all()
81+
assert len(users) == 2
82+
83+
84+
async def test_inherited_base_model_managers_product():
85+
await Product.query.create(name="test", rating=5)
86+
await Product.query.create(name="test2", rating=4)
87+
await Product.query.create(name="test3", rating=3)
88+
await Product.query.create(name="test4", rating=2)
89+
await Product.query.create(name="test5", rating=2)
90+
await Product.query.create(name="test6", rating=1)
91+
92+
users = await Product.query.all()
93+
assert len(users) == 6
94+
95+
users = await Product.ratings.all()
96+
assert len(users) == 3

0 commit comments

Comments
 (0)