Skip to content

Commit dd02f9f

Browse files
authored
Refactor tests to use pytest.fixtures (#42)
refactor test_mapper and test_loader to use pytest.fixtures
1 parent 479c50d commit dd02f9f

File tree

3 files changed

+111
-155
lines changed

3 files changed

+111
-155
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,5 @@ def sessionmaker(engine) -> orm.sessionmaker:
7373

7474

7575
@pytest.fixture
76-
def Base():
76+
def base():
7777
return orm.declarative_base()

tests/test_loader.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
1-
import asyncio
2-
import unittest
3-
41
import pytest
52
from sqlalchemy import Column, ForeignKey, Integer, String, Table
63
from sqlalchemy.orm import relationship
7-
84
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyLoader
95

10-
116
pytest_plugins = ("pytest_asyncio",)
127

138

14-
def _create_many_to_one_tables(Base):
15-
class Employee(Base):
9+
@pytest.fixture
10+
def many_to_one_tables(base):
11+
class Employee(base):
1612
__tablename__ = "employee"
1713
id = Column(Integer, autoincrement=True, primary_key=True)
1814
name = Column(String, nullable=False)
1915
department_id = Column(Integer, ForeignKey("department.id"))
2016
department = relationship("Department", back_populates="employees")
2117

22-
class Department(Base):
18+
class Department(base):
2319
__tablename__ = "department"
2420
id = Column(Integer, autoincrement=True, primary_key=True)
2521
name = Column(String, nullable=False)
@@ -28,32 +24,33 @@ class Department(Base):
2824
return Employee, Department
2925

3026

31-
def _create_secondary_tables(Base):
27+
@pytest.fixture
28+
def secondary_tables(base):
3229
EmployeeDepartmentJoinTable = Table(
3330
"employee_department_join_table",
34-
Base.metadata,
31+
base.metadata,
3532
Column("employee_id", ForeignKey("employee.e_id"), primary_key=True),
36-
Column("department_id", ForeignKey("department.d_id"), primary_key=True)
33+
Column("department_id", ForeignKey("department.d_id"), primary_key=True),
3734
)
3835

39-
class Employee(Base):
36+
class Employee(base):
4037
__tablename__ = "employee"
4138
e_id = Column(Integer, autoincrement=True, primary_key=True)
4239
name = Column(String, nullable=False)
4340
departments = relationship(
4441
"Department",
4542
secondary="employee_department_join_table",
46-
back_populates="employees"
43+
back_populates="employees",
4744
)
4845

49-
class Department(Base):
46+
class Department(base):
5047
__tablename__ = "department"
5148
d_id = Column(Integer, autoincrement=True, primary_key=True)
5249
name = Column(String, nullable=False)
5350
employees = relationship(
5451
"Employee",
5552
secondary="employee_department_join_table",
56-
back_populates="departments"
53+
back_populates="departments",
5754
)
5855

5956
return Employee, Department
@@ -66,9 +63,9 @@ def test_loader_init():
6663

6764

6865
@pytest.mark.asyncio
69-
async def test_loader_for(engine, Base, sessionmaker):
70-
Employee, Department = _create_many_to_one_tables(Base)
71-
Base.metadata.create_all(engine)
66+
async def test_loader_for(engine, base, sessionmaker, many_to_one_tables):
67+
Employee, Department = many_to_one_tables
68+
base.metadata.create_all(engine)
7269

7370
with sessionmaker() as session:
7471
e1 = Employee(name="e1")
@@ -92,21 +89,31 @@ async def test_loader_for(engine, Base, sessionmaker):
9289
assert loader._loop is None
9390
assert loader.load_fn is not None
9491

95-
key = tuple([getattr(e1, local.key) for local, _ in Employee.department.property.local_remote_pairs])
92+
key = tuple(
93+
[
94+
getattr(e1, local.key)
95+
for local, _ in Employee.department.property.local_remote_pairs
96+
]
97+
)
9698
department = await loader.load(key)
9799
assert department.name == "d2"
98100

99101
loader = base_loader.loader_for(Department.employees.property)
100-
key = tuple([getattr(d2, local.key) for local, _ in Department.employees.property.local_remote_pairs])
102+
key = tuple(
103+
[
104+
getattr(d2, local.key)
105+
for local, _ in Department.employees.property.local_remote_pairs
106+
]
107+
)
101108
employees = await loader.load((d2.id,))
102109
assert {e.name for e in employees} == {"e1"}
103110

104111

105112
@pytest.mark.xfail
106113
@pytest.mark.asyncio
107-
async def test_loader_for_secondary(engine, Base, sessionmaker):
108-
Employee, Department = _create_secondary_tables(Base)
109-
Base.metadata.create_all(engine)
114+
async def test_loader_for_secondary(engine, base, sessionmaker, secondary_tables):
115+
Employee, Department = secondary_tables
116+
base.metadata.create_all(engine)
110117

111118
with sessionmaker() as session:
112119
e1 = Employee(name="e1")
@@ -124,9 +131,14 @@ async def test_loader_for_secondary(engine, Base, sessionmaker):
124131
e2.departments.append(d2)
125132
session.commit()
126133

127-
base_loader=StrawberrySQLAlchemyLoader(bind=session)
134+
base_loader = StrawberrySQLAlchemyLoader(bind=session)
128135
loader = base_loader.loader_for(Employee.departments.property)
129136

130-
key = tuple([getattr(e1, local.key) for local, _ in Employee.departments.property.local_remote_pairs])
137+
key = tuple(
138+
[
139+
getattr(e1, local.key)
140+
for local, _ in Employee.departments.property.local_remote_pairs
141+
]
142+
)
131143
departments = await loader.load(key)
132144
assert {d.name for d in departments} == {"d1", "d2"}

0 commit comments

Comments
 (0)