Skip to content

Commit 668deb0

Browse files
committed
wrote up to add annotation testing
1 parent 08f774f commit 668deb0

File tree

1 file changed

+137
-12
lines changed

1 file changed

+137
-12
lines changed

tests/test_mapper.py

Lines changed: 137 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,187 @@
1-
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
1+
import enum
2+
from typing import List, Optional
3+
4+
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
5+
from sqlalchemy.dialects.postgresql.array import ARRAY
26
from sqlalchemy.ext.declarative import declarative_base
3-
from sqlalchemy import Column, Integer, String
7+
from sqlalchemy.orm import relationship
8+
9+
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
10+
411

512
def _create_employee_table():
613
# todo: use pytest fixtures
714
Base = declarative_base()
15+
816
class Employee(Base):
917
__tablename__ = "employee"
1018
id = Column(Integer, autoincrement=True, primary_key=True)
1119
name = Column(String, nullable=False)
20+
1221
return Employee
1322

23+
24+
def _create_employee_and_department_tables():
25+
# todo: use pytest fixtures
26+
Base = declarative_base()
27+
28+
class Employee(Base):
29+
__tablename__ = "employee"
30+
id = Column(Integer, autoincrement=True, primary_key=True)
31+
name = Column(String, nullable=False)
32+
department_id = Column(Integer, ForeignKey("department.id"))
33+
department = relationship("Department", back_populates="employees")
34+
35+
class Department(Base):
36+
__tablename__ = "department"
37+
id = Column(Integer, autoincrement=True, primary_key=True)
38+
name = Column(String, nullable=False)
39+
employees = relationship("Employee", back_populates="department")
40+
41+
return Employee, Department
42+
43+
1444
def _create_polymorphic_employee_table():
45+
# todo: use pytest fixtures
1546
Base = declarative_base()
47+
1648
class Employee(Base):
1749
__tablename__ = "employee"
1850
id = Column(Integer, autoincrement=True, primary_key=True)
1951
type = Column(String(50))
2052

21-
__mapper_args__ = {
22-
'polymorphic_identity':'employee',
23-
'polymorphic_on':type
24-
}
53+
__mapper_args__ = {"polymorphic_identity": "employee", "polymorphic_on": type}
54+
2555
return Employee
2656

2757

2858
def test_mapper_default_model_to_type_name():
2959
Employee = _create_employee_table()
30-
assert StrawberrySQLAlchemyMapper._default_model_to_type_name(Employee) == "Employee"
60+
assert (
61+
StrawberrySQLAlchemyMapper._default_model_to_type_name(Employee) == "Employee"
62+
)
63+
3164

3265
def test_default_model_to_interface_name():
3366
Employee = _create_employee_table()
34-
assert StrawberrySQLAlchemyMapper._default_model_to_interface_name(Employee) == "EmployeeInterface"
67+
assert (
68+
StrawberrySQLAlchemyMapper._default_model_to_interface_name(Employee)
69+
== "EmployeeInterface"
70+
)
71+
3572

3673
def test_model_is_interface_fails():
3774
Employee = _create_employee_table()
3875
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
39-
assert strawberry_sqlalchemy_mapper.model_is_interface(Employee)is False
76+
assert strawberry_sqlalchemy_mapper.model_is_interface(Employee) is False
77+
4078

4179
def test_model_is_interface_succeeds():
4280
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
4381
Employee = _create_polymorphic_employee_table()
4482
assert strawberry_sqlalchemy_mapper.model_is_interface(Employee) is True
4583

84+
4685
def test_is_model_polymorphic():
4786
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
4887
Employee = _create_polymorphic_employee_table()
4988
assert strawberry_sqlalchemy_mapper._is_model_polymorphic(Employee) is True
5089

90+
5191
def test_edge_type_for():
5292
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
5393
employee_edge_class = strawberry_sqlalchemy_mapper._edge_type_for("Employee")
5494
assert employee_edge_class.__name__ == "EmployeeEdge"
55-
assert employee_edge_class._generated_field_keys == ['node']
95+
assert employee_edge_class._generated_field_keys == ["node"]
96+
5697

5798
def test_connection_type_for():
5899
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
59-
employee_connection_class = strawberry_sqlalchemy_mapper._connection_type_for("Employee")
100+
employee_connection_class = strawberry_sqlalchemy_mapper._connection_type_for(
101+
"Employee"
102+
)
60103
assert employee_connection_class.__name__ == "EmployeeConnection"
61-
assert employee_connection_class._generated_field_keys == ['edges']
104+
assert employee_connection_class._generated_field_keys == ["edges"]
62105
assert employee_connection_class._is_generated_connection_type is True
106+
107+
108+
def test_get_polymorphic_base_model():
109+
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
110+
Employee = _create_polymorphic_employee_table()
111+
112+
class Lawyer(Employee):
113+
pass
114+
115+
class ParaLegal(Lawyer):
116+
pass
117+
118+
assert (
119+
strawberry_sqlalchemy_mapper._get_polymorphic_base_model(Employee) == Employee
120+
)
121+
assert strawberry_sqlalchemy_mapper._get_polymorphic_base_model(Lawyer) == Employee
122+
assert (
123+
strawberry_sqlalchemy_mapper._get_polymorphic_base_model(ParaLegal) == Employee
124+
)
125+
126+
127+
def test_convert_column_to_strawberry_type():
128+
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
129+
int_column = Column(Integer, nullable=False)
130+
assert (
131+
strawberry_sqlalchemy_mapper._convert_column_to_strawberry_type(int_column)
132+
== int
133+
)
134+
string_column = Column(String, nullable=False)
135+
assert (
136+
strawberry_sqlalchemy_mapper._convert_column_to_strawberry_type(string_column)
137+
== str
138+
)
139+
140+
141+
def test_convert_array_column_to_strawberry_type():
142+
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
143+
column = Column(ARRAY(String))
144+
assert (
145+
strawberry_sqlalchemy_mapper._convert_column_to_strawberry_type(column)
146+
== Optional[List[str]]
147+
)
148+
column = Column(ARRAY(String), nullable=False)
149+
assert (
150+
strawberry_sqlalchemy_mapper._convert_column_to_strawberry_type(column)
151+
== List[str]
152+
)
153+
154+
155+
def test_convert_enum_column_to_strawberry_type():
156+
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
157+
158+
class SampleEnum(enum.Enum):
159+
one = 1
160+
two = 2
161+
three = 3
162+
163+
column = Column(Enum(SampleEnum))
164+
assert (
165+
strawberry_sqlalchemy_mapper._convert_column_to_strawberry_type(column)
166+
== Optional[SampleEnum]
167+
)
168+
column = Column(Enum(SampleEnum), nullable=False)
169+
assert (
170+
strawberry_sqlalchemy_mapper._convert_column_to_strawberry_type(column)
171+
== SampleEnum
172+
)
173+
174+
175+
def test_add_annotation():
176+
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
177+
178+
class Base:
179+
a: int = 3
180+
b: str = "abc"
181+
182+
field_keys = []
183+
key = "name"
184+
annotation = "base_name"
185+
strawberry_sqlalchemy_mapper._add_annotation(Base, key, annotation, field_keys)
186+
assert Base.__annotations__[key] == annotation
187+
assert field_keys == [key]

0 commit comments

Comments
 (0)