Skip to content

Commit a54f46e

Browse files
Fix Interface mapping (#24)
Fixes a bug where an Interface is not properly registered, resulting in an infinite-loop for mapping Interfaces to polymorphic Models. Changes: Properly adds Interfaces to mapped_interfaces Updates deprecated _type_definition references to __strawberry_definition__
1 parent f3deccd commit a54f46e

File tree

3 files changed

+67
-6
lines changed

3 files changed

+67
-6
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Release type: patch
2+
3+
Fixes a bug where an Interface is not properly registered, resulting in an infinite-loop for mapping Interfaces to polymorphic Models.

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,11 +683,13 @@ def convert(type_: Any) -> Any:
683683

684684
if make_interface:
685685
mapped_type = strawberry.interface(type_)
686+
self.mapped_interfaces[type_.__name__] = mapped_type
686687
elif use_federation:
687688
mapped_type = strawberry.federation.type(type_)
689+
self.mapped_types[type_.__name__] = mapped_type
688690
else:
689691
mapped_type = strawberry.type(type_)
690-
self.mapped_types[type_.__name__] = mapped_type
692+
self.mapped_types[type_.__name__] = mapped_type
691693
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
692694
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
693695
return mapped_type
@@ -728,7 +730,7 @@ def _fix_annotation_namespaces(self) -> None:
728730
self.edge_types.values(),
729731
self.connection_types.values(),
730732
):
731-
for field in mapped_type._type_definition.fields:
733+
for field in mapped_type.__strawberry_definition__.fields:
732734
if field.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
733735
namespace = {}
734736
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):

tests/test_mapper.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
from typing import List, Optional
33

4+
import pytest
45
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
56
from sqlalchemy.dialects.postgresql.array import ARRAY
67
from sqlalchemy.ext.declarative import declarative_base
@@ -10,6 +11,34 @@
1011
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
1112

1213

14+
@pytest.fixture
15+
def mapper():
16+
return StrawberrySQLAlchemyMapper()
17+
18+
19+
@pytest.fixture
20+
def polymorphic_employee():
21+
Base = declarative_base()
22+
23+
class Employee(Base):
24+
__tablename__ = "employee"
25+
id = Column(Integer, autoincrement=True, primary_key=True)
26+
type = Column(String(50))
27+
name = Column(String(50))
28+
29+
__mapper_args__ = {"polymorphic_identity": "employee", "polymorphic_on": type}
30+
31+
return Employee
32+
33+
34+
@pytest.fixture
35+
def polymorphic_lawyer(polymorphic_employee):
36+
class Lawyer(polymorphic_employee):
37+
__mapper_args__ = {"polymorphic_identity": "lawyer"}
38+
39+
return Lawyer
40+
41+
1342
def _create_employee_table():
1443
# todo: use pytest fixtures
1544
Base = declarative_base()
@@ -233,14 +262,41 @@ class Employee:
233262
assert len(additional_types) == 1
234263
mapped_employee_type = additional_types[0]
235264
assert mapped_employee_type.__name__ == "Employee"
236-
assert len(mapped_employee_type._type_definition._fields) == 2
237-
employee_type_fields = mapped_employee_type._type_definition._fields
265+
assert len(mapped_employee_type.__strawberry_definition__._fields) == 2
266+
employee_type_fields = mapped_employee_type.__strawberry_definition__._fields
238267
name = list(filter(lambda f: f.name == "name", employee_type_fields))[0]
239268
assert name.type == str
240269
id = list(filter(lambda f: f.name == "id", employee_type_fields))[0]
241270
assert id.type == int
242271

243272

273+
def test_interface_and_type_polymorphic(
274+
mapper, polymorphic_employee, polymorphic_lawyer
275+
):
276+
@mapper.interface(polymorphic_employee)
277+
class EmployeeInterface:
278+
pass
279+
280+
@mapper.type(polymorphic_employee)
281+
class Employee:
282+
pass
283+
284+
@mapper.type(polymorphic_lawyer)
285+
class Lawyer:
286+
pass
287+
288+
mapper.finalize()
289+
290+
additional_interfaces = list(mapper.mapped_interfaces.values())
291+
assert len(additional_interfaces) == 1
292+
mapped_employee_interface_type = additional_interfaces[0]
293+
assert mapped_employee_interface_type.__name__ == "EmployeeInterface"
294+
295+
additional_types = list(mapper.mapped_types.values())
296+
assert len(additional_types) == 2
297+
assert {"Employee", "Lawyer"} == {t.__name__ for t in additional_types}
298+
299+
244300
def test_type_relationships():
245301
Employee, _ = _create_employee_and_department_tables()
246302
strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()
@@ -254,8 +310,8 @@ class Employee:
254310
assert len(additional_types) == 2
255311
mapped_employee_type = additional_types[0]
256312
assert mapped_employee_type.__name__ == "Employee"
257-
assert len(mapped_employee_type._type_definition._fields) == 4
258-
employee_type_fields = mapped_employee_type._type_definition._fields
313+
assert len(mapped_employee_type.__strawberry_definition__._fields) == 4
314+
employee_type_fields = mapped_employee_type.__strawberry_definition__._fields
259315
name = list(filter(lambda f: f.name == "department_id", employee_type_fields))[0]
260316
assert type(name.type) == StrawberryOptional
261317
id = list(filter(lambda f: f.name == "department", employee_type_fields))[0]

0 commit comments

Comments
 (0)