Skip to content

Commit 47147b7

Browse files
committed
Add support for polymorphic foreign key optimization
1 parent 3956238 commit 47147b7

File tree

4 files changed

+77
-7
lines changed

4 files changed

+77
-7
lines changed

strawberry_django/optimizer.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,22 @@ def _get_hints_from_model_property(
702702
return store
703703

704704

705+
def _must_use_prefetch_related(
706+
field: StrawberryField,
707+
model_field: models.ForeignKey | OneToOneRel,
708+
) -> bool:
709+
f_type = _get_django_type(field)
710+
if f_type and hasattr(f_type, "get_queryset"):
711+
# If the field has a get_queryset method, change strategy to Prefetch
712+
# so it will be respected
713+
return True
714+
if is_polymorphic_model(model_field.related_model):
715+
# If the model is using django-polymorphic, change strategy to Prefetch,
716+
# so its custom queryset will be used, returning polymorphic models
717+
return True
718+
return False
719+
720+
705721
def _get_hints_from_django_foreign_key(
706722
field: StrawberryField,
707723
field_definition: GraphQLObjectType,
@@ -717,13 +733,9 @@ def _get_hints_from_django_foreign_key(
717733
cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]],
718734
level: int = 0,
719735
) -> OptimizerStore:
720-
f_type = _get_django_type(field)
721-
if f_type and hasattr(f_type, "get_queryset"):
722-
# If the field has a get_queryset method, change strategy to Prefetch
723-
# so it will be respected
736+
if _must_use_prefetch_related(field, model_field):
724737
store = _get_hints_from_django_relation(
725738
field,
726-
field_definition=field_definition,
727739
field_selection=field_selection,
728740
model_field=model_field,
729741
model_fieldname=model_fieldname,
@@ -772,7 +784,6 @@ def _get_hints_from_django_foreign_key(
772784

773785
def _get_hints_from_django_relation(
774786
field: StrawberryField,
775-
field_definition: GraphQLObjectType,
776787
field_selection: FieldNode,
777788
model_field: (
778789
models.ManyToManyField
@@ -961,7 +972,6 @@ def _get_hints_from_django_field(
961972
elif isinstance(model_field, relation_fields):
962973
store = _get_hints_from_django_relation(
963974
field,
964-
field_definition=field_definition,
965975
field_selection=field_selection,
966976
model_field=model_field,
967977
model_fieldname=model_fieldname,

tests/polymorphism/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
class Company(models.Model):
66
name = models.CharField(max_length=100)
7+
main_project = models.ForeignKey('Project', on_delete=models.CASCADE, null=True)
8+
9+
class Meta:
10+
ordering = ('name',)
711

812

913
class Project(PolymorphicModel):

tests/polymorphism/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class ResearchProjectType(ProjectType):
2626
class CompanyType:
2727
name: strawberry.auto
2828
projects: list[ProjectType]
29+
main_project: ProjectType | None
2930

3031

3132
@strawberry.type

tests/polymorphism/test_optimizer.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,61 @@ def test_polymorphic_offset_paginated_query():
164164
}
165165

166166

167+
@pytest.mark.django_db(transaction=True)
168+
def test_polymorphic_relation():
169+
ap = ArtProject.objects.create(topic="Art", artist="Artist")
170+
art_company = Company.objects.create(name="ArtCompany", main_project=ap)
171+
172+
rp = ResearchProject.objects.create(
173+
topic="Research", supervisor="Supervisor"
174+
)
175+
research_company = Company.objects.create(name="ResearchCompany", main_project=rp)
176+
177+
query = """\
178+
query {
179+
companies {
180+
name
181+
mainProject {
182+
__typename
183+
topic
184+
... on ArtProjectType {
185+
artist
186+
}
187+
... on ResearchProjectType {
188+
supervisor
189+
}
190+
}
191+
}
192+
}
193+
"""
194+
195+
# Company, ContentType, base table, two subtables = 5 queries
196+
with assert_num_queries(1):
197+
result = schema.execute_sync(query)
198+
assert not result.errors
199+
assert result.data == {
200+
"companies": [
201+
{
202+
"name": art_company.name,
203+
"mainProject": {
204+
"__typename": "ArtProjectType",
205+
"topic": ap.topic,
206+
"artist": ap.artist,
207+
},
208+
},
209+
{
210+
"name": research_company.name,
211+
"mainProject": {
212+
"__typename": "ResearchProjectType",
213+
"topic": rp.topic,
214+
"supervisor": rp.supervisor,
215+
},
216+
}
217+
]
218+
}
219+
220+
221+
167222
@pytest.mark.django_db(transaction=True)
168223
def test_polymorphic_nested_list():
169224
company = Company.objects.create(name="Company")

0 commit comments

Comments
 (0)