1
- import asyncio
2
- import unittest
3
-
4
1
import pytest
5
2
from sqlalchemy import Column , ForeignKey , Integer , String , Table
6
3
from sqlalchemy .orm import relationship
7
-
8
4
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyLoader
9
5
10
-
11
6
pytest_plugins = ("pytest_asyncio" ,)
12
7
13
8
14
- def _create_many_to_one_tables (Base ):
15
- class Employee (Base ):
9
+ @pytest .fixture
10
+ def many_to_one_tables (base ):
11
+ class Employee (base ):
16
12
__tablename__ = "employee"
17
13
id = Column (Integer , autoincrement = True , primary_key = True )
18
14
name = Column (String , nullable = False )
19
15
department_id = Column (Integer , ForeignKey ("department.id" ))
20
16
department = relationship ("Department" , back_populates = "employees" )
21
17
22
- class Department (Base ):
18
+ class Department (base ):
23
19
__tablename__ = "department"
24
20
id = Column (Integer , autoincrement = True , primary_key = True )
25
21
name = Column (String , nullable = False )
@@ -28,32 +24,33 @@ class Department(Base):
28
24
return Employee , Department
29
25
30
26
31
- def _create_secondary_tables (Base ):
27
+ @pytest .fixture
28
+ def secondary_tables (base ):
32
29
EmployeeDepartmentJoinTable = Table (
33
30
"employee_department_join_table" ,
34
- Base .metadata ,
31
+ base .metadata ,
35
32
Column ("employee_id" , ForeignKey ("employee.e_id" ), primary_key = True ),
36
- Column ("department_id" , ForeignKey ("department.d_id" ), primary_key = True )
33
+ Column ("department_id" , ForeignKey ("department.d_id" ), primary_key = True ),
37
34
)
38
35
39
- class Employee (Base ):
36
+ class Employee (base ):
40
37
__tablename__ = "employee"
41
38
e_id = Column (Integer , autoincrement = True , primary_key = True )
42
39
name = Column (String , nullable = False )
43
40
departments = relationship (
44
41
"Department" ,
45
42
secondary = "employee_department_join_table" ,
46
- back_populates = "employees"
43
+ back_populates = "employees" ,
47
44
)
48
45
49
- class Department (Base ):
46
+ class Department (base ):
50
47
__tablename__ = "department"
51
48
d_id = Column (Integer , autoincrement = True , primary_key = True )
52
49
name = Column (String , nullable = False )
53
50
employees = relationship (
54
51
"Employee" ,
55
52
secondary = "employee_department_join_table" ,
56
- back_populates = "departments"
53
+ back_populates = "departments" ,
57
54
)
58
55
59
56
return Employee , Department
@@ -66,9 +63,9 @@ def test_loader_init():
66
63
67
64
68
65
@pytest .mark .asyncio
69
- async def test_loader_for (engine , Base , sessionmaker ):
70
- Employee , Department = _create_many_to_one_tables ( Base )
71
- Base .metadata .create_all (engine )
66
+ async def test_loader_for (engine , base , sessionmaker , many_to_one_tables ):
67
+ Employee , Department = many_to_one_tables
68
+ base .metadata .create_all (engine )
72
69
73
70
with sessionmaker () as session :
74
71
e1 = Employee (name = "e1" )
@@ -92,21 +89,31 @@ async def test_loader_for(engine, Base, sessionmaker):
92
89
assert loader ._loop is None
93
90
assert loader .load_fn is not None
94
91
95
- key = tuple ([getattr (e1 , local .key ) for local , _ in Employee .department .property .local_remote_pairs ])
92
+ key = tuple (
93
+ [
94
+ getattr (e1 , local .key )
95
+ for local , _ in Employee .department .property .local_remote_pairs
96
+ ]
97
+ )
96
98
department = await loader .load (key )
97
99
assert department .name == "d2"
98
100
99
101
loader = base_loader .loader_for (Department .employees .property )
100
- key = tuple ([getattr (d2 , local .key ) for local , _ in Department .employees .property .local_remote_pairs ])
102
+ key = tuple (
103
+ [
104
+ getattr (d2 , local .key )
105
+ for local , _ in Department .employees .property .local_remote_pairs
106
+ ]
107
+ )
101
108
employees = await loader .load ((d2 .id ,))
102
109
assert {e .name for e in employees } == {"e1" }
103
110
104
111
105
112
@pytest .mark .xfail
106
113
@pytest .mark .asyncio
107
- async def test_loader_for_secondary (engine , Base , sessionmaker ):
108
- Employee , Department = _create_secondary_tables ( Base )
109
- Base .metadata .create_all (engine )
114
+ async def test_loader_for_secondary (engine , base , sessionmaker , secondary_tables ):
115
+ Employee , Department = secondary_tables
116
+ base .metadata .create_all (engine )
110
117
111
118
with sessionmaker () as session :
112
119
e1 = Employee (name = "e1" )
@@ -124,9 +131,14 @@ async def test_loader_for_secondary(engine, Base, sessionmaker):
124
131
e2 .departments .append (d2 )
125
132
session .commit ()
126
133
127
- base_loader = StrawberrySQLAlchemyLoader (bind = session )
134
+ base_loader = StrawberrySQLAlchemyLoader (bind = session )
128
135
loader = base_loader .loader_for (Employee .departments .property )
129
136
130
- key = tuple ([getattr (e1 , local .key ) for local , _ in Employee .departments .property .local_remote_pairs ])
137
+ key = tuple (
138
+ [
139
+ getattr (e1 , local .key )
140
+ for local , _ in Employee .departments .property .local_remote_pairs
141
+ ]
142
+ )
131
143
departments = await loader .load (key )
132
144
assert {d .name for d in departments } == {"d1" , "d2" }
0 commit comments