Skip to content

Commit b59e6a4

Browse files
Improve support for type inheritance from other mapped types (#253)
* add test_inheritance_table * fix inheritance with mapped types * remove mapped_collum to work with sqlalchemy 1.4 * fix schema expected * adding new fixes to python 3.8 and 3.9 * Update docs and release * Update RELEASE.md Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Update README.md Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Update src/strawberry_sqlalchemy_mapper/mapper.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * adding fixtures --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
1 parent ea37387 commit b59e6a4

File tree

4 files changed

+525
-14
lines changed

4 files changed

+525
-14
lines changed

README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,46 @@ query {
152152
"""
153153
```
154154

155+
### Type Inheritance
156+
157+
You can inherit fields from other mapped types using standard Python class inheritance.
158+
159+
- Fields from the parent type (e.g., ApiA) are inherited by the child (e.g., ApiB).
160+
161+
- The `__exclude__` setting applies to inherited fields.
162+
163+
- If both SQLAlchemy models define the same field name, the field from the model inside `.type(...)` takes precedence.
164+
165+
- Declaring a field manually in the mapped type overrides everything else.
166+
167+
```python
168+
class ModelA(base):
169+
__tablename__ = "a"
170+
171+
id = Column(String, primary_key=True)
172+
common_field = Column(String(50))
173+
174+
175+
class ModelB(base):
176+
__tablename__ = "b"
177+
178+
id = Column(String, primary_key=True)
179+
common_field = Column(Integer) # Conflicting field
180+
extra_field = Column(String(50))
181+
182+
183+
@mapper.type(ModelA)
184+
class ApiA:
185+
__exclude__ = ["id"] # This field will be excluded in ApiA (and its children)
186+
187+
188+
@mapper.type(ModelB)
189+
class ApiB(ApiA):
190+
# Inherits fields from ApiA, except "id"
191+
# "common_field" will come from ModelB, not ModelA, so it will be a Integer
192+
# "extra_field" will be overridden and will be a float now instead of the String type declared in ModelB:
193+
extra_field: float = strawberry.field(name="extraField")
194+
```
155195
## Limitations
156196

157197
### Supported Types

RELEASE.md

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
Release type: patch
2+
3+
This release improves how types inherit fields from other mapped types using `@mapper.type(...)`.
4+
You can now safely inherit from another mapped type, and the resulting GraphQL type will include all expected fields with predictable conflict resolution.
5+
6+
Some examples:
7+
8+
- Basic Inheritance:
9+
10+
```python
11+
@mapper.type(ModelA)
12+
class ApiA:
13+
pass
14+
15+
16+
@mapper.type(ModelB)
17+
class ApiB(ApiA):
18+
# ApiB inherits all fields declared in ApiA
19+
pass
20+
```
21+
22+
23+
- The `__exclude__` option continues working:
24+
25+
```python
26+
@mapper.type(ModelA)
27+
class ApiA:
28+
__exclude__ = ["relationshipB_id"]
29+
30+
31+
@mapper.type(ModelB)
32+
class ApiB(ApiA):
33+
# ApiB will have all fields declared in ApiA, except "relationshipB_id"
34+
pass
35+
```
36+
37+
- If two SQLAlchemy models define fields with the same name, the field from the model inside `.type(...)` takes precedence:
38+
39+
```python
40+
class ModelA(base):
41+
__tablename__ = "a"
42+
43+
id = Column(String, primary_key=True)
44+
example_field = Column(String(50))
45+
46+
47+
class ModelB(base):
48+
__tablename__ = "b"
49+
50+
id = Column(String, primary_key=True)
51+
example_field = Column(Integer, autoincrement=True)
52+
53+
54+
@mapper.type(ModelA)
55+
class ApiA:
56+
# example_field will be a String
57+
pass
58+
59+
60+
@mapper.type(ModelB)
61+
class ApiB(ApiA):
62+
# example_field will be taken from ModelB and will be an Integer
63+
pass
64+
```
65+
66+
67+
- If a field is explicitly declared in the mapped type, it will override any inherited or model-based definition:
68+
69+
```python
70+
class ModelA(base):
71+
__tablename__ = "a"
72+
73+
id = Column(String, primary_key=True)
74+
example_field = Column(String(50))
75+
76+
77+
class ModelB(base):
78+
__tablename__ = "b"
79+
80+
id = Column(String, primary_key=True)
81+
example_field = Column(Integer, autoincrement=True)
82+
83+
84+
@mapper.type(ModelA)
85+
class ApiA:
86+
pass
87+
88+
89+
@mapper.type(ModelB)
90+
class ApiB(ApiA):
91+
# example_field will be a Float
92+
example_field: float = strawberry.field(name="exampleField")
93+
```

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
Protocol,
2828
Sequence,
2929
Set,
30+
Tuple,
3031
Type,
3132
TypeVar,
3233
Union,
3334
cast,
35+
get_type_hints,
3436
overload,
3537
)
3638
from typing_extensions import Self
@@ -651,27 +653,47 @@ class Employee:
651653
```
652654
"""
653655

656+
def _get_generated_field_keys(type_, old_annotations) -> Tuple[List[str], Dict[str, Any]]:
657+
old_annotations = old_annotations.copy()
658+
generated_field_keys = set()
659+
660+
for key in dir(type_):
661+
val = getattr(type_, key)
662+
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
663+
setattr(type_, key, field(resolver=val))
664+
generated_field_keys.add(key)
665+
666+
# Checks for an original type annotation, useful in resolving inheritance-related types
667+
if original_type := getattr(type_, _ORIGINAL_TYPE_KEY, None):
668+
for key in dir(original_type):
669+
if key.startswith("__") and key.endswith("__"):
670+
continue
671+
672+
val = getattr(original_type, key)
673+
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
674+
setattr(type_, key, field(resolver=val))
675+
generated_field_keys.add(key)
676+
try:
677+
annotations = get_type_hints(original_type)
678+
except Exception:
679+
annotations = original_type.__annotations__
680+
681+
if key in annotations:
682+
old_annotations[key] = annotations[key]
683+
684+
return list(generated_field_keys), old_annotations
685+
654686
def convert(type_: Any) -> Any:
655687
old_annotations = getattr(type_, "__annotations__", {})
656688
type_.__annotations__ = {k: v for k, v in old_annotations.items() if is_private(v)}
657689
mapper: Mapper = cast("Mapper", inspect(model))
658-
generated_field_keys = []
659690

660691
excluded_keys = getattr(type_, "__exclude__", [])
661692
list_keys = getattr(type_, "__use_list__", [])
662693

663-
# if the type inherits from another mapped type, then it may have
664-
# generated resolvers. These will be treated by dataclasses as having
665-
# a default value, which will likely cause issues because of keys
666-
# that don't have default values. To fix this, we wrap them in
667-
# `strawberry.field()` (like when they were originally made), so
668-
# dataclasses will ignore them.
669-
# TODO: Potentially raise/fix this issue upstream
670-
for key in dir(type_):
671-
val = getattr(type_, key)
672-
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
673-
setattr(type_, key, field(resolver=val))
674-
generated_field_keys.append(key)
694+
generated_field_keys, old_annotations = _get_generated_field_keys(
695+
type_, old_annotations
696+
)
675697

676698
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
677699
relationship: RelationshipProperty
@@ -798,7 +820,15 @@ def convert(type_: Any) -> Any:
798820
# because the pre-existing fields might have default values,
799821
# which will cause the mapped fields to fail
800822
# (because they may not have default values)
801-
type_.__annotations__.update(old_annotations)
823+
824+
# For Python versions <= 3.9, only update annotations that don't already exist
825+
# because this versions handle inherance differently
826+
if sys.version_info[:2] <= (3, 9):
827+
for k, v in old_annotations.items():
828+
if k not in type_.__annotations__:
829+
type_.__annotations__[k] = v
830+
else:
831+
type_.__annotations__.update(old_annotations)
802832

803833
if make_interface:
804834
mapped_type = strawberry.interface(type_)

0 commit comments

Comments
 (0)