1
1
import enum
2
2
from typing import List , Optional
3
3
4
+ import pytest
4
5
from sqlalchemy import Column , Enum , ForeignKey , Integer , String
5
6
from sqlalchemy .dialects .postgresql .array import ARRAY
6
7
from sqlalchemy .ext .declarative import declarative_base
10
11
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
11
12
12
13
14
+ @pytest .fixture
15
+ def mapper ():
16
+ return StrawberrySQLAlchemyMapper ()
17
+
18
+
19
+ @pytest .fixture
20
+ def polymorphic_employee ():
21
+ Base = declarative_base ()
22
+
23
+ class Employee (Base ):
24
+ __tablename__ = "employee"
25
+ id = Column (Integer , autoincrement = True , primary_key = True )
26
+ type = Column (String (50 ))
27
+ name = Column (String (50 ))
28
+
29
+ __mapper_args__ = {"polymorphic_identity" : "employee" , "polymorphic_on" : type }
30
+
31
+ return Employee
32
+
33
+
34
+ @pytest .fixture
35
+ def polymorphic_lawyer (polymorphic_employee ):
36
+ class Lawyer (polymorphic_employee ):
37
+ __mapper_args__ = {"polymorphic_identity" : "lawyer" }
38
+
39
+ return Lawyer
40
+
41
+
13
42
def _create_employee_table ():
14
43
# todo: use pytest fixtures
15
44
Base = declarative_base ()
@@ -233,14 +262,41 @@ class Employee:
233
262
assert len (additional_types ) == 1
234
263
mapped_employee_type = additional_types [0 ]
235
264
assert mapped_employee_type .__name__ == "Employee"
236
- assert len (mapped_employee_type ._type_definition ._fields ) == 2
237
- employee_type_fields = mapped_employee_type ._type_definition ._fields
265
+ assert len (mapped_employee_type .__strawberry_definition__ ._fields ) == 2
266
+ employee_type_fields = mapped_employee_type .__strawberry_definition__ ._fields
238
267
name = list (filter (lambda f : f .name == "name" , employee_type_fields ))[0 ]
239
268
assert name .type == str
240
269
id = list (filter (lambda f : f .name == "id" , employee_type_fields ))[0 ]
241
270
assert id .type == int
242
271
243
272
273
+ def test_interface_and_type_polymorphic (
274
+ mapper , polymorphic_employee , polymorphic_lawyer
275
+ ):
276
+ @mapper .interface (polymorphic_employee )
277
+ class EmployeeInterface :
278
+ pass
279
+
280
+ @mapper .type (polymorphic_employee )
281
+ class Employee :
282
+ pass
283
+
284
+ @mapper .type (polymorphic_lawyer )
285
+ class Lawyer :
286
+ pass
287
+
288
+ mapper .finalize ()
289
+
290
+ additional_interfaces = list (mapper .mapped_interfaces .values ())
291
+ assert len (additional_interfaces ) == 1
292
+ mapped_employee_interface_type = additional_interfaces [0 ]
293
+ assert mapped_employee_interface_type .__name__ == "EmployeeInterface"
294
+
295
+ additional_types = list (mapper .mapped_types .values ())
296
+ assert len (additional_types ) == 2
297
+ assert {"Employee" , "Lawyer" } == {t .__name__ for t in additional_types }
298
+
299
+
244
300
def test_type_relationships ():
245
301
Employee , _ = _create_employee_and_department_tables ()
246
302
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper ()
@@ -254,8 +310,8 @@ class Employee:
254
310
assert len (additional_types ) == 2
255
311
mapped_employee_type = additional_types [0 ]
256
312
assert mapped_employee_type .__name__ == "Employee"
257
- assert len (mapped_employee_type ._type_definition ._fields ) == 4
258
- employee_type_fields = mapped_employee_type ._type_definition ._fields
313
+ assert len (mapped_employee_type .__strawberry_definition__ ._fields ) == 4
314
+ employee_type_fields = mapped_employee_type .__strawberry_definition__ ._fields
259
315
name = list (filter (lambda f : f .name == "department_id" , employee_type_fields ))[0 ]
260
316
assert type (name .type ) == StrawberryOptional
261
317
id = list (filter (lambda f : f .name == "department" , employee_type_fields ))[0 ]
0 commit comments